Skip to content

[mlir] Add the ability to override attribute parsing/printing in attr-dicts #103304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Aug 13, 2024

This allows operations to use a custom parser for known attributes (i.e. attributes known in the op declaration), even if that attribute is part of the attr-dict.

One thing it allows for changing the types of attributes (in the attr-dict) without breaking familiar syntax.

For example, if you have:

vector.extract_strided_slice %vector { offsets = [0, 1, 2, 3], ... }

By default (using the generic parsing) offsets will be a ArrayAttr (of IntegerAttr). If you want to replace that ArrayAttr with a DenseI64ArrayAttr, by default that'll look like this:

vector.extract_strided_slice %vector { offsets = array<i64: 0, 1, 2, 3>, ... }

However, if you use the functionally added in the change, you can switch to a DenseI64ArrayAttr without changing the syntax.


This is done by adding two callbacks:

parseNamedAttrFn to AsmParser::parseOptionalAttrDict().

If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute, if it can parse the attribute it returns the parsed attribute, otherwise, it returns failure() which indicates that generic parsing should be used. Note: Returning a null Attribute from parseNamedAttrFn indicates a parser error.

and printNamedAttrFn to AsmPrinter::printOptionalAttrDict().

If printNamedAttrFn is provided the default printing can be overridden for a named attribute. printNamedAttrFn is passed a NamedAttribute, if it prints the attribute it returns success(), otherwise, it returns failure() which indicates that generic printing should be used.


Note: Currently this requires implementing a custom parser/printer for your operation. In future, it could be possible to generate the callbacks via the ODS.

…-dicts

This adds a `parseNamedAttrFn` callback to
`AsmParser::parseOptionalAttrDict()`.

If parseNamedAttrFn is provided the default parsing can be overridden
for a named attribute. parseNamedAttrFn is passed the name of an
attribute, if it can parse the attribute it returns the parsed
attribute, otherwise, it returns `failure()` which indicates that
generic parsing should be used. Note: Returning a null Attribute from
parseNamedAttrFn indicates a parser error.

It also adds `printNamedAttrFn` to
`AsmPrinter::printOptionalAttrDict()`.

If printNamedAttrFn is provided the default printing can be overridden
for a named attribute. printNamedAttrFn is passed a NamedAttribute, if
it prints the attribute it returns `success()`, otherwise, it returns
`failure()` which indicates that generic printing should be used.
@MacDue MacDue requested a review from banach-space August 13, 2024 16:10
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 13, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This adds a parseNamedAttrFn callback to AsmParser::parseOptionalAttrDict().

If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute, if it can parse the attribute it returns the parsed attribute, otherwise, it returns failure() which indicates that generic parsing should be used. Note: Returning a null Attribute from parseNamedAttrFn indicates a parser error.

It also adds printNamedAttrFn to AsmPrinter::printOptionalAttrDict().

If printNamedAttrFn is provided the default printing can be overridden for a named attribute. printNamedAttrFn is passed a NamedAttribute, if it prints the attribute it returns success(), otherwise, it returns failure() which indicates that generic printing should be used.


Full diff: https://github.com/llvm/llvm-project/pull/103304.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/OpImplementation.h (+19-4)
  • (modified) mlir/lib/AsmParser/AsmParserImpl.h (+5-2)
  • (modified) mlir/lib/AsmParser/AttributeParser.cpp (+14-2)
  • (modified) mlir/lib/AsmParser/Parser.h (+3-1)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+30-15)
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
   /// If the specified operation has attributes, print out an attribute
   /// dictionary with their values.  elidedAttrs allows the client to ignore
   /// specific well known attributes, commonly used if the attribute value is
-  /// printed some other way (like as a fixed operand).
+  /// printed some other way (like as a fixed operand). If printNamedAttrFn is
+  /// provided the default printing can be overridden for a named attribute.
+  /// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
+  /// it returns `success()`, otherwise, it returns `failure()` which indicates
+  /// that generic printing should be used.
   virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                     ArrayRef<StringRef> elidedAttrs = {}) = 0;
+                                     ArrayRef<StringRef> elidedAttrs = {},
+                                     function_ref<LogicalResult(NamedAttribute)>
+                                         printNamedAttrFn = nullptr) = 0;
 
   /// If the specified operation has attributes, print out an attribute
   /// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
     return parseResult;
   }
 
-  /// Parse a named dictionary into 'result' if it is present.
-  virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+  /// Parse a named dictionary into 'result' if it is present. If
+  /// parseNamedAttrFn is provided the default parsing can be overridden for a
+  /// named attribute. parseNamedAttrFn is passed the name of an attribute, if
+  /// it can parse the attribute it returns the parsed attribute, otherwise, it
+  /// returns `failure()` which indicates that generic parsing should be used.
+  /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
+  /// error.
+  virtual ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) = 0;
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
   /// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
   }
 
   /// Parse a named dictionary into 'result' if it is present.
-  ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+  ParseResult parseOptionalAttrDict(
+      NamedAttrList &result,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+          nullptr) override {
     if (parser.getToken().isNot(Token::l_brace))
       return success();
-    return parser.parseAttributeDict(result);
+    return parser.parseAttributeDict(result, parseNamedAttrFn);
   }
 
   /// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
 ///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+    NamedAttrList &attributes,
+    function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
   llvm::SmallDenseSet<StringAttr> seenKeys;
   auto parseElt = [&]() -> ParseResult {
     // The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
       return success();
     }
 
-    auto attr = parseAttribute();
+    Attribute attr = nullptr;
+    FailureOr<Attribute> customParsedAttribute;
+    // Try to parse with `printNamedAttrFn` callback.
+    if (parseNamedAttrFn &&
+        succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+      attr = *customParsedAttribute;
+    } else {
+      // Otherwise, use generic attribute parser.
+      attr = parseAttribute();
+    }
+
     if (!attr)
       return failure();
     attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
   }
 
   /// Parse an attribute dictionary.
-  ParseResult parseAttributeDict(NamedAttrList &attributes);
+  ParseResult parseAttributeDict(
+      NamedAttrList &attributes,
+      function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
 
   /// Parse a distinct attribute.
   Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
   void printDimensionList(ArrayRef<int64_t> shape);
 
 protected:
-  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {},
-                             bool withKeyword = false);
-  void printNamedAttribute(NamedAttribute attr);
+  void printOptionalAttrDict(
+      ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+      bool withKeyword = false,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+  void printNamedAttribute(
+      NamedAttribute attr,
+      function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
   void printTrailingLocation(Location loc, bool allowAlias = true);
   void printLocationInternal(LocationAttr loc, bool pretty = false,
                              bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
   /// Print the given set of attributes with names not included within
   /// 'elidedAttrs'.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    if (attrs.empty())
-      return;
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    (void)printNamedAttrFn;
     if (elidedAttrs.empty()) {
       for (const NamedAttribute &attr : attrs)
         printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
       .Default([&](Type type) { return printDialectType(type); });
 }
 
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                                             ArrayRef<StringRef> elidedAttrs,
-                                             bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+    ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+    bool withKeyword,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // If there are no attributes, then there is nothing to be done.
   if (attrs.empty())
     return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
 
     // Otherwise, print them all out in braces.
     os << " {";
-    interleaveComma(filteredAttrs,
-                    [&](NamedAttribute attr) { printNamedAttribute(attr); });
+    interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+      printNamedAttribute(attr, printNamedAttrFn);
+    });
     os << '}';
   };
 
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
   if (!filteredAttrs.empty())
     printFilteredAttributesFn(filteredAttrs);
 }
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+    NamedAttribute attr,
+    function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
   // Print the name without quotes if possible.
   ::printKeywordOrString(attr.getName().strref(), os);
 
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
     return;
 
   os << " = ";
+  if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+    /// If we print via the `printNamedAttrFn` callback skip printing.
+    return;
+  }
   printAttribute(attr.getValue());
 }
 
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
 
   /// Print an optional attribute dictionary with a given set of elided values.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
-                             ArrayRef<StringRef> elidedAttrs = {}) override {
-    Impl::printOptionalAttrDict(attrs, elidedAttrs);
+                             ArrayRef<StringRef> elidedAttrs = {},
+                             function_ref<LogicalResult(NamedAttribute)>
+                                 printNamedAttrFn = nullptr) override {
+    Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+                                printNamedAttrFn);
   }
   void printOptionalAttrDictWithKeyword(
       ArrayRef<NamedAttribute> attrs,

@banach-space
Copy link
Contributor

If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute,

What if we wanted to override printing for multiple named attributes?

@MacDue
Copy link
Member Author

MacDue commented Aug 14, 2024

If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute,

What if we wanted to override printing for multiple named attributes?

It's called for each attribute so just use a StringSwitch in your callback.

@joker-eph
Copy link
Collaborator

Can you expand on the motivation in the PR description? Also provide some tests: right now I'm not sure how it is to be used or what's the intent behind the feature.

Thanks!

@MacDue
Copy link
Member Author

MacDue commented Aug 14, 2024

Can you expand on the motivation in the PR description? Also provide some tests: right now I'm not sure how it is to be used or what's the intent behind the feature.

Thanks!

Added some tests now 👍
Please see the new summary for the motivation. Note: This initially came from #100336, but I think this is generally useful.

@joker-eph
Copy link
Collaborator

Thanks for elaborating!

I'm wondering if this is really that commonly needed or if it isn't instead quite "exotic"?

This can already be supported in custom C++ assembly, so may be what's missing instead is the ability to have a custom call-back for the declarative assembly attr-dict ?

As of #100336 ; we have some cleanup to do in terms of syntax as with the introduction of properties we should really move all the inherent attributes outside of the attr-dict printing anyway: which makes this feature even more exotic I think.

@MacDue
Copy link
Member Author

MacDue commented Aug 14, 2024

This can already be supported in custom C++ assembly, so may be what's missing instead is the ability to have a custom call-back for the declarative assembly attr-dict ?

This PR adds the easy way to implement this in custom C++ assembly (otherwise, you'd have to manually parse the whole attr-dict for each operation that wants something like this). I'm not sure I follow on the second point.

As of #100336 ; we have some cleanup to do in terms of syntax as with the introduction of properties we should really move all the inherent attributes outside of the attr-dict printing anyway: which makes this feature even more exotic I think.

I did propose moving the attribute out of the attr-dict, but people seemed generally against a change of syntax.

@joker-eph
Copy link
Collaborator

joker-eph commented Aug 14, 2024

This PR adds the easy way to implement this in custom C++ assembly

I don't think this PR is needed to implement it in custom C++: you can print { then iterate the dictionary and custom print anything there.

I'm not sure I follow on the second point.

Second point was that instead of having to do let hasCustomAssemblyFormat = 1;, one may want to continue using declarative assembly with custom directive just for the attr-dict : let assemblyFormat = "custom<MyPrintParse>(attr-dict)".

I did propose moving the attribute out of the attr-dict, but people seemed generally against a change of syntax.

This is something that we should just consider deprecated: this will really happen, I just have been slacking on this (prop-dict exists but isn't mandatory, yet).

@MacDue
Copy link
Member Author

MacDue commented Aug 14, 2024

I don't think this PR is needed to implement it in custom C++: you can print { then iterate the dictionary and custom print anything there.

I know it's not required, you could manually implement both the printing and parsing of the attr-dict. Still, I think it's easier to use what's already there, as there are some complications, particularly in generally parsing the attr-dict (which you can see in the implementation of parseAttributeDict()).

Second point was that instead of having to do let hasCustomAssemblyFormat = 1;, one may want to continue using declarative assembly with custom directive just for the attr-dict : let assemblyFormat = "custom<MyPrintParse>(attr-dict)".

I think having a declarative away to do something like this would be nice 🙂, though note that the first use case (the vector transfer ops), already has a custom printer/parser for other reasons. So, simply providing these callbacks is an easy change.

@joker-eph
Copy link
Collaborator

I looked at how you intend to use it, and now I'm even questioning the feature here: #100336 (comment)

@banach-space
Copy link
Contributor

I did propose moving the attribute out of the attr-dict, but people seemed generally against a change of syntax.

IMHO, we should avoid a situation where some named attributes are inside attr-dict, while others are outside:

vector.transfer_read %arg0[%0, %1], %cst, in_bounds = [true] { permutation_map = #map3} : memref<12x16xf32>, vector<8xf32>

Unless there's some rationale for that other than pretty-printing?

Btw, it sounds like we should take a step back here and look at the bigger picture:

we have some cleanup to do in terms of syntax as with the introduction of properties we should really move all the inherent attributes outside of the attr-dict printing anyway

I'll need to get back to your presentation on properties before I have an opinion. Going back to #100336, I worry that we might be too attached to the current syntax. This would be perfectly fine with me:

vector.transfer_read %arg0[%0, %1], %cst {in_bounds = array<i1: true>, permutation_map = #map3} : memref<12x16xf32>, vector<8xf32>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants