Skip to content

[MLIR] Add ODS support for generating helpers for dialect (discardable) attributes #77024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 20, 2024

Conversation

joker-eph
Copy link
Collaborator

@joker-eph joker-eph commented Jan 4, 2024

This is a new ODS feature that allows dialects to define a list of
key/value pair representing an attribute type and a name.
This will generate helper classes on the dialect to be able to
manage discardable attributes on operations in a type safe way.

For example the test dialect can define:

  let discardableAttrs = (ins
     "mlir::IntegerAttr":$discardable_attr_key,
  );

And the following will be generated in the TestDialect class:

   /// Helper to manage the discardable attribute `discardable_attr_key`.
    class DiscardableAttrKeyAttrHelper {
      ::mlir::StringAttr name;
    public:
      static constexpr ::llvm::StringLiteral getNameStr() {
        return "test.discardable_attr_key";
      }
      constexpr ::mlir::StringAttr getName() {
        return name;
      }

      DiscardableAttrKeyAttrHelper(::mlir::MLIRContext *ctx)
        : name(::mlir::StringAttr::get(ctx, getNameStr())) {}

     mlir::IntegerAttr getAttr(::mlir::Operation *op) {
       return op->getAttrOfType<mlir::IntegerAttr>(name);
     }
     void setAttr(::mlir::Operation *op, mlir::IntegerAttr val) {
       op->setAttr(name, val);
     }
     bool isAttrPresent(::mlir::Operation *op) {
       return op->hasAttrOfType<mlir::IntegerAttr>(name);
     }
     void removeAttr(::mlir::Operation *op) {
       assert(op->hasAttrOfType<mlir::IntegerAttr>(name));
       op->removeAttr(name);
     }
   };
   DiscardableAttrKeyAttrHelper getDiscardableAttrKeyAttrHelper() {
     return discardableAttrKeyAttrName;
   }

User code having an instance of the TestDialect can then manipulate this
attribute on operation using:

  auto helper = testDialect.getDiscardableAttrKeyAttrHelper();

  helper.setAttr(op, value);
  helper.isAttrPresent(op);
  ...

@llvmbot
Copy link
Member

llvmbot commented Jan 4, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir-llvm

Author: Mehdi Amini (joker-eph)

Changes

WIP (missing docs)

See #75118 for context


Full diff: https://github.com/llvm/llvm-project/pull/77024.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+8-10)
  • (modified) mlir/include/mlir/IR/DialectBase.td (+2)
  • (modified) mlir/include/mlir/TableGen/Dialect.h (+6)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+9-8)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+2-2)
  • (modified) mlir/lib/TableGen/Dialect.cpp (+4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp (+6-4)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.td (+4)
  • (modified) mlir/tools/mlir-tblgen/DialectGen.cpp (+85-4)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..6abcfdf60e0fd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
   let hasOperationAttrVerify = 1;
 
   let extraClassDeclaration = [{
-    /// Get the name of the attribute used to annotate external kernel
-    /// functions.
-    static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
-    static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.flat_work_group_size");
-    }
-    static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
-    }
-
     /// The address space value that represents global memory.
     static constexpr unsigned kGlobalMemoryAddressSpace = 1;
     /// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
     static constexpr unsigned kPrivateMemoryAddressSpace = 5;
   }];
 
+  let discardableAttrs = (ins
+     "::mlir::UnitAttr":$kernel,
+     "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+     "::mlir::StringAttr":$flat_work_group_size,
+     "::mlir::IntegerAttr":$max_flat_work_group_size,
+     "::mlir::IntegerAttr":$waves_per_eu
+  );
+
   let useDefaultAttributePrinterParser = 1;
 }
 
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..16750dc7e4d320 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,8 @@ class Dialect {
   // pattern or interfaces.
   list<string> dependentDialects = [];
 
+  dag discardableAttrs = (ins);
+
   // The C++ namespace that ops of this dialect should be placed into.
   //
   // By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
 #define MLIR_TABLEGEN_DIALECT_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
 #include <string>
 #include <vector>
 
@@ -90,6 +92,10 @@ class Dialect {
   /// dialect.
   bool usePropertiesForAttributes() const;
 
+  llvm::DagInit *getDiscardableAttributes() const;
+
+  const llvm::Record *getDef() const { return def; }
+
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12d..abd2733ba4fbd1 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
-
+    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+    auto reqdWorkGroupSizeAttrHelper =
+        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+    auto flatWorkGroupSizeAttrHelper =
+        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
     // Manually rewrite known block size attributes so the LLVMIR translation
     // infrastructure can pick them up.
-    m.walk([ctx](LLVM::LLVMFuncOp op) {
+    m.walk([&](LLVM::LLVMFuncOp op) {
       if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
               op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
-        op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
-                    blockSizes);
+        reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
         // Also set up the rocdl.flat_work_group_size attribute to prevent
         // conflicting metadata.
         uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
         }
         StringAttr flatSizeAttr =
             StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
-        op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
-                    flatSizeAttr);
+        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
       }
     });
   }
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
       converter,
       /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
       /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
-      StringAttr::get(&converter.getContext(),
-                      ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+      ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
   if (Runtime::HIP == runtime) {
     patterns.add<GPUPrintfOpToHIPLowering>(converter);
   } else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
 LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
                                                      NamedAttribute attr) {
   // Kernel function attribute should be attached to functions.
-  if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+  if (kernelAttrName.getName() == attr.getName()) {
     if (!isa<LLVM::LLVMFuncOp>(op)) {
-      return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+      return op->emitError() << "'" << kernelAttrName.getName()
                              << "' attribute attached to unexpected op";
     }
   }
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
   return def->getValueAsBit("usePropertiesForAttributes");
 }
 
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+  return def->getValueAsDag("discardableAttrs");
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 55a6285ec87eb4..6783ffcde6d531 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
   amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                  NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+    auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+    if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -106,7 +107,8 @@ class ROCDLDialectLLVMIRTranslationInterface
     // Override flat-work-group-size
     // TODO: update clients to rocdl.flat_work_group_size instead,
     // then remove this half of the branch
-    if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+    if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+        attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -121,7 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
       attrValueStream << "1," << value.getInt();
       llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
     }
-    if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+    if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
@@ -138,7 +140,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     }
 
     // Set reqd_work_group_size metadata
-    if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+    if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let isExtensible = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
+  let discardableAttrs = (ins
+     "mlir::IntegerAttr":$discardable_attr_key,
+     "SimpleAAttr":$other_discardable_attr_key
+  );
 
   let extraClassDeclaration = [{
     void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..a0ebd9ce8f29e3 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
                           std::function<bool(const llvm::Record *)>>;
 } // namespace
 
+static void populateDiscardableAttributes(
+    Dialect &dialect, llvm::DagInit *discardableAttrDag,
+    SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+  for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
+    llvm::Init *arg = discardableAttrDag->getArg(i);
+
+    StringRef givenName = discardableAttrDag->getArgNameStr(i);
+    if (givenName.empty())
+      PrintFatalError(dialect.getDef()->getLoc(),
+                      "discardable attributes must be named");
+    discardableAttributes.push_back(
+        {givenName.str(), arg->getAsUnquotedString()});
+  }
+}
+
 /// Given a set of records for a T, filter the ones that correspond to
 /// the given dialect.
 template <typename T>
@@ -181,6 +196,37 @@ static const char *const operationInterfaceFallbackDecl = R"(
                                       mlir::OperationName opName) override;
 )";
 
+/// The code block for the discardable attribute helper.
+static const char *const discardableAttrHelperDecl = R"(
+    /// Helper to manage the discardable attribute `{1}`.
+    class {0}AttrHelper {{
+      mlir::StringAttr name;
+    public:
+      static constexpr llvm::StringLiteral getNameStr() {{
+        return "{4}.{1}";
+      }
+      constexpr mlir::StringAttr getName() {{
+        return name;
+      }
+
+      {0}AttrHelper(mlir::MLIRContext *ctx)
+        : name(mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+     {2} getAttr(::mlir::Operation *op) {{
+       return op->getAttrOfType<{2}>(getName());
+     }
+     void setAttr(::mlir::Operation *op, {2} val) {{
+       op->setAttr(getName(), val);
+     }
+   };
+   {0}AttrHelper get{0}AttrHelper() {
+     return {3}AttrName;
+   }
+ private:
+   {0}AttrHelper {3}AttrName;
+ public:
+)";
+
 /// Generate the declaration for the given dialect class.
 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
   // Emit all nested namespaces.
@@ -216,6 +262,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
       os << regionResultAttrVerifierDecl;
     if (dialect.hasOperationInterfaceFallback())
       os << operationInterfaceFallbackDecl;
+
+    llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+    SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+    populateDiscardableAttributes(dialect, discardableAttrDag,
+                                  discardableAttributes);
+
+    for (const auto &attrPair : discardableAttributes) {
+      std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/true);
+      std::string camelName = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/false);
+      os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+                          attrPair.first, attrPair.second, camelName,
+                          dialect.getName());
+    }
+
     if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
       os << *extraDecl;
 
@@ -253,9 +315,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
 /// {1}: initialization code that is emitted in the ctor body before calling
 ///      initialize().
 /// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
 static const char *const dialectConstructorStr = R"(
 {0}::{0}(::mlir::MLIRContext *context)
-    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+    {3}
+     {{
   {1}
   initialize();
 }
@@ -269,7 +334,9 @@ static const char *const dialectDestructorStr = R"(
 
 )";
 
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+                           const llvm::RecordKeeper &recordKeeper,
+                           raw_ostream &os) {
   std::string cppClassName = dialect.getCppClassName();
 
   // Emit the TypeID explicit specializations to have a single symbol def.
@@ -292,8 +359,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
   // Emit the constructor and destructor.
   StringRef superClassName =
       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+  llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+  SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+  populateDiscardableAttributes(dialect, discardableAttrDag,
+                                discardableAttributes);
+  std::string discardableAttributesInit;
+  for (const auto &attrPair : discardableAttributes) {
+    std::string camelName = llvm::convertToCamelFromSnakeCase(
+        attrPair.first, /*capitalizeFirst=*/false);
+    llvm::raw_string_ostream os(discardableAttributesInit);
+    os << ", " << camelName << "AttrName(context)";
+  }
+
   os << llvm::formatv(dialectConstructorStr, cppClassName,
-                      dependentDialectRegistrations, superClassName);
+                      dependentDialectRegistrations, superClassName,
+                      discardableAttributesInit);
   if (!dialect.hasNonDefaultDestructor())
     os << llvm::formatv(dialectDestructorStr, cppClassName);
 }
@@ -310,7 +391,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
   std::optional<Dialect> dialect = findDialectToGenerate(dialects);
   if (!dialect)
     return true;
-  emitDialectDef(*dialect, os);
+  emitDialectDef(*dialect, recordKeeper, os);
   return false;
 }
 

@llvmbot
Copy link
Member

llvmbot commented Jan 4, 2024

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Changes

WIP (missing docs)

See #75118 for context


Full diff: https://github.com/llvm/llvm-project/pull/77024.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+8-10)
  • (modified) mlir/include/mlir/IR/DialectBase.td (+2)
  • (modified) mlir/include/mlir/TableGen/Dialect.h (+6)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+9-8)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (+2-2)
  • (modified) mlir/lib/TableGen/Dialect.cpp (+4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp (+6-4)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.td (+4)
  • (modified) mlir/tools/mlir-tblgen/DialectGen.cpp (+85-4)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 48b830ae34f292..6abcfdf60e0fd0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -28,16 +28,6 @@ def ROCDL_Dialect : Dialect {
   let hasOperationAttrVerify = 1;
 
   let extraClassDeclaration = [{
-    /// Get the name of the attribute used to annotate external kernel
-    /// functions.
-    static StringRef getKernelFuncAttrName() { return "rocdl.kernel"; }
-    static constexpr ::llvm::StringLiteral getFlatWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.flat_work_group_size");
-    }
-    static constexpr ::llvm::StringLiteral getReqdWorkGroupSizeAttrName() {
-      return ::llvm::StringLiteral("rocdl.reqd_work_group_size");
-    }
-
     /// The address space value that represents global memory.
     static constexpr unsigned kGlobalMemoryAddressSpace = 1;
     /// The address space value that represents shared memory.
@@ -46,6 +36,14 @@ def ROCDL_Dialect : Dialect {
     static constexpr unsigned kPrivateMemoryAddressSpace = 5;
   }];
 
+  let discardableAttrs = (ins
+     "::mlir::UnitAttr":$kernel,
+     "::mlir::DenseI32ArrayAttr":$reqd_work_group_size,
+     "::mlir::StringAttr":$flat_work_group_size,
+     "::mlir::IntegerAttr":$max_flat_work_group_size,
+     "::mlir::IntegerAttr":$waves_per_eu
+  );
+
   let useDefaultAttributePrinterParser = 1;
 }
 
diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td
index 5afa23933ea1f7..16750dc7e4d320 100644
--- a/mlir/include/mlir/IR/DialectBase.td
+++ b/mlir/include/mlir/IR/DialectBase.td
@@ -34,6 +34,8 @@ class Dialect {
   // pattern or interfaces.
   list<string> dependentDialects = [];
 
+  dag discardableAttrs = (ins);
+
   // The C++ namespace that ops of this dialect should be placed into.
   //
   // By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 5337bd3beb5f9d..3530d240c976c6 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,6 +14,8 @@
 #define MLIR_TABLEGEN_DIALECT_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/TableGen/Record.h"
+
 #include <string>
 #include <vector>
 
@@ -90,6 +92,10 @@ class Dialect {
   /// dialect.
   bool usePropertiesForAttributes() const;
 
+  llvm::DagInit *getDiscardableAttributes() const;
+
+  const llvm::Record *getDef() const { return def; }
+
   // Returns whether two dialects are equal by checking the equality of the
   // underlying record.
   bool operator==(const Dialect &other) const;
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 599bb13190f12d..abd2733ba4fbd1 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -285,14 +285,17 @@ struct LowerGpuOpsToROCDLOpsPass
     configureGpuToROCDLConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
       signalPassFailure();
-
+    auto *rocdlDialect = getContext().getLoadedDialect<ROCDL::ROCDLDialect>();
+    auto reqdWorkGroupSizeAttrHelper =
+        rocdlDialect->getReqdWorkGroupSizeAttrHelper();
+    auto flatWorkGroupSizeAttrHelper =
+        rocdlDialect->getFlatWorkGroupSizeAttrHelper();
     // Manually rewrite known block size attributes so the LLVMIR translation
     // infrastructure can pick them up.
-    m.walk([ctx](LLVM::LLVMFuncOp op) {
+    m.walk([&](LLVM::LLVMFuncOp op) {
       if (auto blockSizes = dyn_cast_or_null<DenseI32ArrayAttr>(
               op->removeAttr(gpu::GPUFuncOp::getKnownBlockSizeAttrName()))) {
-        op->setAttr(ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName(),
-                    blockSizes);
+        reqdWorkGroupSizeAttrHelper.setAttr(op, blockSizes);
         // Also set up the rocdl.flat_work_group_size attribute to prevent
         // conflicting metadata.
         uint32_t flatSize = 1;
@@ -301,8 +304,7 @@ struct LowerGpuOpsToROCDLOpsPass
         }
         StringAttr flatSizeAttr =
             StringAttr::get(ctx, Twine(flatSize) + "," + Twine(flatSize));
-        op->setAttr(ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName(),
-                    flatSizeAttr);
+        flatWorkGroupSizeAttrHelper.setAttr(op, flatSizeAttr);
       }
     });
   }
@@ -355,8 +357,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
       converter,
       /*allocaAddrSpace=*/ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace,
       /*workgroupAddrSpace=*/ROCDL::ROCDLDialect::kSharedMemoryAddressSpace,
-      StringAttr::get(&converter.getContext(),
-                      ROCDL::ROCDLDialect::getKernelFuncAttrName()));
+      ROCDL::ROCDLDialect::KernelAttrHelper(&converter.getContext()).getName());
   if (Runtime::HIP == runtime) {
     patterns.add<GPUPrintfOpToHIPLowering>(converter);
   } else if (Runtime::OpenCL == runtime) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 26e46b31ddc018..0f2e75cd7e8bc7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -253,9 +253,9 @@ void ROCDLDialect::initialize() {
 LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
                                                      NamedAttribute attr) {
   // Kernel function attribute should be attached to functions.
-  if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
+  if (kernelAttrName.getName() == attr.getName()) {
     if (!isa<LLVM::LLVMFuncOp>(op)) {
-      return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
+      return op->emitError() << "'" << kernelAttrName.getName()
                              << "' attribute attached to unexpected op";
     }
   }
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 6924a2862eef07..081f6e56f9ded4 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,6 +106,10 @@ bool Dialect::usePropertiesForAttributes() const {
   return def->getValueAsBit("usePropertiesForAttributes");
 }
 
+llvm::DagInit *Dialect::getDiscardableAttributes() const {
+  return def->getValueAsDag("discardableAttrs");
+}
+
 bool Dialect::operator==(const Dialect &other) const {
   return def == other.def;
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 55a6285ec87eb4..6783ffcde6d531 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -84,7 +84,8 @@ class ROCDLDialectLLVMIRTranslationInterface
   amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
                  NamedAttribute attribute,
                  LLVM::ModuleTranslation &moduleTranslation) const final {
-    if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
+    auto *dialect = dyn_cast<ROCDL::ROCDLDialect>(attribute.getNameDialect());
+    if (dialect->getKernelAttrHelper().getName() == attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -106,7 +107,8 @@ class ROCDLDialectLLVMIRTranslationInterface
     // Override flat-work-group-size
     // TODO: update clients to rocdl.flat_work_group_size instead,
     // then remove this half of the branch
-    if ("rocdl.max_flat_work_group_size" == attribute.getName()) {
+    if (dialect->getMaxFlatWorkGroupSizeAttrHelper().getName() ==
+        attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
         return failure();
@@ -121,7 +123,7 @@ class ROCDLDialectLLVMIRTranslationInterface
       attrValueStream << "1," << value.getInt();
       llvmFunc->addFnAttr("amdgpu-flat-work-group-size", llvmAttrValue);
     }
-    if (ROCDL::ROCDLDialect::getFlatWorkGroupSizeAttrName() ==
+    if (dialect->getFlatWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
@@ -138,7 +140,7 @@ class ROCDLDialectLLVMIRTranslationInterface
     }
 
     // Set reqd_work_group_size metadata
-    if (ROCDL::ROCDLDialect::getReqdWorkGroupSizeAttrName() ==
+    if (dialect->getReqdWorkGroupSizeAttrHelper().getName() ==
         attribute.getName()) {
       auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
       if (!func)
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 8524d5b1458447..2b5491fc0c6a02 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -25,6 +25,10 @@ def Test_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 1;
   let isExtensible = 1;
   let dependentDialects = ["::mlir::DLTIDialect"];
+  let discardableAttrs = (ins
+     "mlir::IntegerAttr":$discardable_attr_key,
+     "SimpleAAttr":$other_discardable_attr_key
+  );
 
   let extraClassDeclaration = [{
     void registerAttributes();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..a0ebd9ce8f29e3 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -43,6 +43,21 @@ using DialectFilterIterator =
                           std::function<bool(const llvm::Record *)>>;
 } // namespace
 
+static void populateDiscardableAttributes(
+    Dialect &dialect, llvm::DagInit *discardableAttrDag,
+    SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
+  for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
+    llvm::Init *arg = discardableAttrDag->getArg(i);
+
+    StringRef givenName = discardableAttrDag->getArgNameStr(i);
+    if (givenName.empty())
+      PrintFatalError(dialect.getDef()->getLoc(),
+                      "discardable attributes must be named");
+    discardableAttributes.push_back(
+        {givenName.str(), arg->getAsUnquotedString()});
+  }
+}
+
 /// Given a set of records for a T, filter the ones that correspond to
 /// the given dialect.
 template <typename T>
@@ -181,6 +196,37 @@ static const char *const operationInterfaceFallbackDecl = R"(
                                       mlir::OperationName opName) override;
 )";
 
+/// The code block for the discardable attribute helper.
+static const char *const discardableAttrHelperDecl = R"(
+    /// Helper to manage the discardable attribute `{1}`.
+    class {0}AttrHelper {{
+      mlir::StringAttr name;
+    public:
+      static constexpr llvm::StringLiteral getNameStr() {{
+        return "{4}.{1}";
+      }
+      constexpr mlir::StringAttr getName() {{
+        return name;
+      }
+
+      {0}AttrHelper(mlir::MLIRContext *ctx)
+        : name(mlir::StringAttr::get(ctx, getNameStr())) {{}
+
+     {2} getAttr(::mlir::Operation *op) {{
+       return op->getAttrOfType<{2}>(getName());
+     }
+     void setAttr(::mlir::Operation *op, {2} val) {{
+       op->setAttr(getName(), val);
+     }
+   };
+   {0}AttrHelper get{0}AttrHelper() {
+     return {3}AttrName;
+   }
+ private:
+   {0}AttrHelper {3}AttrName;
+ public:
+)";
+
 /// Generate the declaration for the given dialect class.
 static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
   // Emit all nested namespaces.
@@ -216,6 +262,22 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
       os << regionResultAttrVerifierDecl;
     if (dialect.hasOperationInterfaceFallback())
       os << operationInterfaceFallbackDecl;
+
+    llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+    SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+    populateDiscardableAttributes(dialect, discardableAttrDag,
+                                  discardableAttributes);
+
+    for (const auto &attrPair : discardableAttributes) {
+      std::string camelNameUpper = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/true);
+      std::string camelName = llvm::convertToCamelFromSnakeCase(
+          attrPair.first, /*capitalizeFirst=*/false);
+      os << llvm::formatv(discardableAttrHelperDecl, camelNameUpper,
+                          attrPair.first, attrPair.second, camelName,
+                          dialect.getName());
+    }
+
     if (std::optional<StringRef> extraDecl = dialect.getExtraClassDeclaration())
       os << *extraDecl;
 
@@ -253,9 +315,12 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
 /// {1}: initialization code that is emitted in the ctor body before calling
 ///      initialize().
 /// {2}: The dialect parent class.
+/// {3}: Extra members to initialize
 static const char *const dialectConstructorStr = R"(
 {0}::{0}(::mlir::MLIRContext *context)
-    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
+    : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>())
+    {3}
+     {{
   {1}
   initialize();
 }
@@ -269,7 +334,9 @@ static const char *const dialectDestructorStr = R"(
 
 )";
 
-static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
+static void emitDialectDef(Dialect &dialect,
+                           const llvm::RecordKeeper &recordKeeper,
+                           raw_ostream &os) {
   std::string cppClassName = dialect.getCppClassName();
 
   // Emit the TypeID explicit specializations to have a single symbol def.
@@ -292,8 +359,22 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
   // Emit the constructor and destructor.
   StringRef superClassName =
       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
+
+  llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+  SmallVector<std::pair<std::string, std::string>> discardableAttributes;
+  populateDiscardableAttributes(dialect, discardableAttrDag,
+                                discardableAttributes);
+  std::string discardableAttributesInit;
+  for (const auto &attrPair : discardableAttributes) {
+    std::string camelName = llvm::convertToCamelFromSnakeCase(
+        attrPair.first, /*capitalizeFirst=*/false);
+    llvm::raw_string_ostream os(discardableAttributesInit);
+    os << ", " << camelName << "AttrName(context)";
+  }
+
   os << llvm::formatv(dialectConstructorStr, cppClassName,
-                      dependentDialectRegistrations, superClassName);
+                      dependentDialectRegistrations, superClassName,
+                      discardableAttributesInit);
   if (!dialect.hasNonDefaultDestructor())
     os << llvm::formatv(dialectDestructorStr, cppClassName);
 }
@@ -310,7 +391,7 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
   std::optional<Dialect> dialect = findDialectToGenerate(dialects);
   if (!dialect)
     return true;
-  emitDialectDef(*dialect, os);
+  emitDialectDef(*dialect, recordKeeper, os);
   return false;
 }
 

Copy link
Contributor

@sjw36 sjw36 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's it @joker-eph, Thanks! Appreciate it!

Just a minor ask.

@joker-eph joker-eph requested review from jpienaar and Mogball January 5, 2024 15:57
@@ -34,6 +34,8 @@ class Dialect {
// pattern or interfaces.
list<string> dependentDialects = [];

dag discardableAttrs = (ins);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is my mental model correct that these are all effectively optional attributes for which there are helpers generated? But they are different than the others as not verified, most likely empty? (e.g., I'm not sure if doing this vs making these optional attributes when one would select which ...)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if I followed correctly what you're asking about, but these are helpers to manage discardable attributes. The idea is that for example the ROCDL dialect involves a discardable attribute with the key rocdl.reqd_work_group_size, and the associated Attribute must be a DenseI32ArrayAttr.
This feature will generate a helper for managing this discardable attribute on operations. The operations don't know anything about the attribute, it's a property of the dialect itself.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me. These being dialect known discardable rather than op known optional.

@sjw36
Copy link
Contributor

sjw36 commented Feb 6, 2024

Looks like a spurious windows failure or just needs updated to TIP?

@joker-eph
Copy link
Collaborator Author

Sorry, I need to address the comments, I'm just temporarily out of a workstation! Feel free to ping next week

@sjw36
Copy link
Contributor

sjw36 commented Feb 7, 2024

Sorry, I need to address the comments, I'm just temporarily out of a workstation! Feel free to ping next week

No worries. I am suffering the same fate..

…e) attributes

This is a new ODS feature that allows dialects to define a list of
key/value pair representing an attribute type and a name.
This will generate helper classes on the dialect to be able to
manage discardable attributes on operations in a type safe way.

For example the `test` dialect can define:

```
  let discardableAttrs = (ins
     "mlir::IntegerAttr":$discardable_attr_key,
  );
```

And the following will be generated in the TestDialect class:

```
   /// Helper to manage the discardable attribute `discardable_attr_key`.
    class DiscardableAttrKeyAttrHelper {
      ::mlir::StringAttr name;
    public:
      static constexpr ::llvm::StringLiteral getNameStr() {
        return "test.discardable_attr_key";
      }
      constexpr ::mlir::StringAttr getName() {
        return name;
      }

      DiscardableAttrKeyAttrHelper(::mlir::MLIRContext *ctx)
        : name(::mlir::StringAttr::get(ctx, getNameStr())) {}

     mlir::IntegerAttr getAttr(::mlir::Operation *op) {
       return op->getAttrOfType<mlir::IntegerAttr>(name);
     }
     void setAttr(::mlir::Operation *op, mlir::IntegerAttr val) {
       op->setAttr(name, val);
     }
     bool isAttrPresent(::mlir::Operation *op) {
       return op->hasAttrOfType<mlir::IntegerAttr>(name);
     }
     void removeAttr(::mlir::Operation *op) {
       assert(op->hasAttrOfType<mlir::IntegerAttr>(name));
       op->removeAttr(name);
     }
   };
   DiscardableAttrKeyAttrHelper getDiscardableAttrKeyAttrHelper() {
     return discardableAttrKeyAttrName;
   }
```

User code having an instance of the TestDialect can then manipulate this
attribute on operation using:

```
  auto helper = testDialect.getDiscardableAttrKeyAttrHelper();

  helper.setAttr(op, value);
  helper.isAttrPresent(op);
  ...
```
@joker-eph joker-eph force-pushed the discardableAttrsHelper branch from f64d744 to 40a814a Compare February 20, 2024 02:04
@joker-eph joker-eph merged commit 45c226d into llvm:main Feb 20, 2024
@joker-eph joker-eph deleted the discardableAttrsHelper branch February 20, 2024 07:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants