-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add the ability to override attribute parsing/printing in attr-dicts #103304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…-dicts This adds a `parseNamedAttrFn` callback to `AsmParser::parseOptionalAttrDict()`. If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute, if it can parse the attribute it returns the parsed attribute, otherwise, it returns `failure()` which indicates that generic parsing should be used. Note: Returning a null Attribute from parseNamedAttrFn indicates a parser error. It also adds `printNamedAttrFn` to `AsmPrinter::printOptionalAttrDict()`. If printNamedAttrFn is provided the default printing can be overridden for a named attribute. printNamedAttrFn is passed a NamedAttribute, if it prints the attribute it returns `success()`, otherwise, it returns `failure()` which indicates that generic printing should be used.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis adds a If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute, if it can parse the attribute it returns the parsed attribute, otherwise, it returns It also adds If printNamedAttrFn is provided the default printing can be overridden for a named attribute. printNamedAttrFn is passed a NamedAttribute, if it prints the attribute it returns Full diff: https://github.com/llvm/llvm-project/pull/103304.diff 5 Files Affected:
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index ae412c7227f8ea..5891cbffc9542d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
- /// printed some other way (like as a fixed operand).
+ /// printed some other way (like as a fixed operand). If printNamedAttrFn is
+ /// provided the default printing can be overridden for a named attribute.
+ /// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
+ /// it returns `success()`, otherwise, it returns `failure()` which indicates
+ /// that generic printing should be used.
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) = 0;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) = 0;
/// If the specified operation has attributes, print out an attribute
/// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
return parseResult;
}
- /// Parse a named dictionary into 'result' if it is present.
- virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
+ /// Parse a named dictionary into 'result' if it is present. If
+ /// parseNamedAttrFn is provided the default parsing can be overridden for a
+ /// named attribute. parseNamedAttrFn is passed the name of an attribute, if
+ /// it can parse the attribute it returns the parsed attribute, otherwise, it
+ /// returns `failure()` which indicates that generic parsing should be used.
+ /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
+ /// error.
+ virtual ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) = 0;
/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index b12687833e3fde..808b2ca282f64b 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
}
/// Parse a named dictionary into 'result' if it is present.
- ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+ ParseResult parseOptionalAttrDict(
+ NamedAttrList &result,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
+ nullptr) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
- return parser.parseAttributeDict(result);
+ return parser.parseAttributeDict(result, parseNamedAttrFn);
}
/// Parse a named dictionary into 'result' if the `attributes` keyword is
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..b687d822e7cb7d 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
/// | `{` attribute-entry (`,` attribute-entry)* `}`
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+ParseResult Parser::parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
llvm::SmallDenseSet<StringAttr> seenKeys;
auto parseElt = [&]() -> ParseResult {
// The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
}
- auto attr = parseAttribute();
+ Attribute attr = nullptr;
+ FailureOr<Attribute> customParsedAttribute;
+ // Try to parse with `printNamedAttrFn` callback.
+ if (parseNamedAttrFn &&
+ succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
+ attr = *customParsedAttribute;
+ } else {
+ // Otherwise, use generic attribute parser.
+ attr = parseAttribute();
+ }
+
if (!attr)
return failure();
attributes.push_back({*nameId, attr});
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 4caab499e1a0e4..d5d90f391fd391 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -256,7 +256,9 @@ class Parser {
}
/// Parse an attribute dictionary.
- ParseResult parseAttributeDict(NamedAttrList &attributes);
+ ParseResult parseAttributeDict(
+ NamedAttrList &attributes,
+ function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
/// Parse a distinct attribute.
Attribute parseDistinctAttr(Type type);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 02acc8c3f4659e..cd9f70c8868b83 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
void printDimensionList(ArrayRef<int64_t> shape);
protected:
- void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {},
- bool withKeyword = false);
- void printNamedAttribute(NamedAttribute attr);
+ void printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
+ bool withKeyword = false,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
+ void printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printTrailingLocation(Location loc, bool allowAlias = true);
void printLocationInternal(LocationAttr loc, bool pretty = false,
bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- if (attrs.empty())
- return;
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ (void)printNamedAttrFn;
if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs)
printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}
-void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs,
- bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(
+ ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
+ bool withKeyword,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
// Otherwise, print them all out in braces.
os << " {";
- interleaveComma(filteredAttrs,
- [&](NamedAttribute attr) { printNamedAttribute(attr); });
+ interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
+ printNamedAttribute(attr, printNamedAttrFn);
+ });
os << '}';
};
@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
-void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(
+ NamedAttribute attr,
+ function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);
@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
return;
os << " = ";
+ if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
+ /// If we print via the `printNamedAttrFn` callback skip printing.
+ return;
+ }
printAttribute(attr.getValue());
}
@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs = {}) override {
- Impl::printOptionalAttrDict(attrs, elidedAttrs);
+ ArrayRef<StringRef> elidedAttrs = {},
+ function_ref<LogicalResult(NamedAttribute)>
+ printNamedAttrFn = nullptr) override {
+ Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
+ printNamedAttrFn);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
|
What if we wanted to override printing for multiple named attributes? |
It's called for each attribute so just use a StringSwitch in your callback. |
Can you expand on the motivation in the PR description? Also provide some tests: right now I'm not sure how it is to be used or what's the intent behind the feature. Thanks! |
Added some tests now 👍 |
Thanks for elaborating! I'm wondering if this is really that commonly needed or if it isn't instead quite "exotic"? This can already be supported in custom C++ assembly, so may be what's missing instead is the ability to have a custom call-back for the declarative assembly As of #100336 ; we have some cleanup to do in terms of syntax as with the introduction of properties we should really move all the inherent attributes outside of the |
This PR adds the easy way to implement this in custom C++ assembly (otherwise, you'd have to manually parse the whole
I did propose moving the attribute out of the |
I don't think this PR is needed to implement it in custom C++: you can print
Second point was that instead of having to do
This is something that we should just consider deprecated: this will really happen, I just have been slacking on this ( |
I know it's not required, you could manually implement both the printing and parsing of the
I think having a declarative away to do something like this would be nice 🙂, though note that the first use case (the vector transfer ops), already has a custom printer/parser for other reasons. So, simply providing these callbacks is an easy change. |
I looked at how you intend to use it, and now I'm even questioning the feature here: #100336 (comment) |
IMHO, we should avoid a situation where some named attributes are inside vector.transfer_read %arg0[%0, %1], %cst, in_bounds = [true] { permutation_map = #map3} : memref<12x16xf32>, vector<8xf32> Unless there's some rationale for that other than pretty-printing? Btw, it sounds like we should take a step back here and look at the bigger picture:
I'll need to get back to your presentation on properties before I have an opinion. Going back to #100336, I worry that we might be too attached to the current syntax. This would be perfectly fine with me: vector.transfer_read %arg0[%0, %1], %cst {in_bounds = array<i1: true>, permutation_map = #map3} : memref<12x16xf32>, vector<8xf32> |
This allows operations to use a custom parser for known attributes (i.e. attributes known in the op declaration), even if that attribute is part of the
attr-dict
.One thing it allows for changing the types of attributes (in the attr-dict) without breaking familiar syntax.
For example, if you have:
By default (using the generic parsing)
offsets
will be a ArrayAttr (of IntegerAttr). If you want to replace thatArrayAttr
with a DenseI64ArrayAttr, by default that'll look like this:However, if you use the functionally added in the change, you can switch to a DenseI64ArrayAttr without changing the syntax.
This is done by adding two callbacks:
parseNamedAttrFn
toAsmParser::parseOptionalAttrDict()
.If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute, if it can parse the attribute it returns the parsed attribute, otherwise, it returns
failure()
which indicates that generic parsing should be used. Note: Returning a null Attribute from parseNamedAttrFn indicates a parser error.and
printNamedAttrFn
toAsmPrinter::printOptionalAttrDict()
.If printNamedAttrFn is provided the default printing can be overridden for a named attribute. printNamedAttrFn is passed a NamedAttribute, if it prints the attribute it returns
success()
, otherwise, it returnsfailure()
which indicates that generic printing should be used.Note: Currently this requires implementing a custom parser/printer for your operation. In future, it could be possible to generate the callbacks via the ODS.