Skip to content

Commit 490cf97

Browse files
MacDuebanach-space
authored andcommitted
[mlir] Add the ability to override attribute parsing/printing in attr-dicts
This adds a `parseNamedAttrFn` callback to `AsmParser::parseOptionalAttrDict()`. If parseNamedAttrFn is provided the default parsing can be overridden for a named attribute. parseNamedAttrFn is passed the name of an attribute, if it can parse the attribute it returns the parsed attribute, otherwise, it returns `failure()` which indicates that generic parsing should be used. Note: Returning a null Attribute from parseNamedAttrFn indicates a parser error. It also adds `printNamedAttrFn` to `AsmPrinter::printOptionalAttrDict()`. If printNamedAttrFn is provided the default printing can be overridden for a named attribute. printNamedAttrFn is passed a NamedAttribute, if it prints the attribute it returns `success()`, otherwise, it returns `failure()` which indicates that generic printing should be used.
1 parent 93f5c61 commit 490cf97

File tree

5 files changed

+71
-24
lines changed

5 files changed

+71
-24
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
463463
/// If the specified operation has attributes, print out an attribute
464464
/// dictionary with their values. elidedAttrs allows the client to ignore
465465
/// specific well known attributes, commonly used if the attribute value is
466-
/// printed some other way (like as a fixed operand).
466+
/// printed some other way (like as a fixed operand). If printNamedAttrFn is
467+
/// provided the default printing can be overridden for a named attribute.
468+
/// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
469+
/// it returns `success()`, otherwise, it returns `failure()` which indicates
470+
/// that generic printing should be used.
467471
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
468-
ArrayRef<StringRef> elidedAttrs = {}) = 0;
472+
ArrayRef<StringRef> elidedAttrs = {},
473+
function_ref<LogicalResult(NamedAttribute)>
474+
printNamedAttrFn = nullptr) = 0;
469475

470476
/// If the specified operation has attributes, print out an attribute
471477
/// dictionary prefixed with 'attributes'.
@@ -1116,8 +1122,17 @@ class AsmParser {
11161122
return parseResult;
11171123
}
11181124

1119-
/// Parse a named dictionary into 'result' if it is present.
1120-
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
1125+
/// Parse a named dictionary into 'result' if it is present. If
1126+
/// parseNamedAttrFn is provided the default parsing can be overridden for a
1127+
/// named attribute. parseNamedAttrFn is passed the name of an attribute, if
1128+
/// it can parse the attribute it returns the parsed attribute, otherwise, it
1129+
/// returns `failure()` which indicates that generic parsing should be used.
1130+
/// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
1131+
/// error.
1132+
virtual ParseResult parseOptionalAttrDict(
1133+
NamedAttrList &result,
1134+
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
1135+
nullptr) = 0;
11211136

11221137
/// Parse a named dictionary into 'result' if the `attributes` keyword is
11231138
/// present.

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
458458
}
459459

460460
/// Parse a named dictionary into 'result' if it is present.
461-
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
461+
ParseResult parseOptionalAttrDict(
462+
NamedAttrList &result,
463+
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
464+
nullptr) override {
462465
if (parser.getToken().isNot(Token::l_brace))
463466
return success();
464-
return parser.parseAttributeDict(result);
467+
return parser.parseAttributeDict(result, parseNamedAttrFn);
465468
}
466469

467470
/// Parse a named dictionary into 'result' if the `attributes` keyword is

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
296296
/// | `{` attribute-entry (`,` attribute-entry)* `}`
297297
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
298298
///
299-
ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
299+
ParseResult Parser::parseAttributeDict(
300+
NamedAttrList &attributes,
301+
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
300302
llvm::SmallDenseSet<StringAttr> seenKeys;
301303
auto parseElt = [&]() -> ParseResult {
302304
// The name of an attribute can either be a bare identifier, or a string.
@@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
329331
return success();
330332
}
331333

332-
auto attr = parseAttribute();
334+
Attribute attr = nullptr;
335+
FailureOr<Attribute> customParsedAttribute;
336+
// Try to parse with `printNamedAttrFn` callback.
337+
if (parseNamedAttrFn &&
338+
succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
339+
attr = *customParsedAttribute;
340+
} else {
341+
// Otherwise, use generic attribute parser.
342+
attr = parseAttribute();
343+
}
344+
333345
if (!attr)
334346
return failure();
335347
attributes.push_back({*nameId, attr});

mlir/lib/AsmParser/Parser.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ class Parser {
256256
}
257257

258258
/// Parse an attribute dictionary.
259-
ParseResult parseAttributeDict(NamedAttrList &attributes);
259+
ParseResult parseAttributeDict(
260+
NamedAttrList &attributes,
261+
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);
260262

261263
/// Parse a distinct attribute.
262264
Attribute parseDistinctAttr(Type type);

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -452,10 +452,13 @@ class AsmPrinter::Impl {
452452
void printDimensionList(ArrayRef<int64_t> shape);
453453

454454
protected:
455-
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
456-
ArrayRef<StringRef> elidedAttrs = {},
457-
bool withKeyword = false);
458-
void printNamedAttribute(NamedAttribute attr);
455+
void printOptionalAttrDict(
456+
ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
457+
bool withKeyword = false,
458+
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
459+
void printNamedAttribute(
460+
NamedAttribute attr,
461+
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
459462
void printTrailingLocation(Location loc, bool allowAlias = true);
460463
void printLocationInternal(LocationAttr loc, bool pretty = false,
461464
bool isTopLevel = false);
@@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
780783
/// Print the given set of attributes with names not included within
781784
/// 'elidedAttrs'.
782785
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
783-
ArrayRef<StringRef> elidedAttrs = {}) override {
784-
if (attrs.empty())
785-
return;
786+
ArrayRef<StringRef> elidedAttrs = {},
787+
function_ref<LogicalResult(NamedAttribute)>
788+
printNamedAttrFn = nullptr) override {
789+
(void)printNamedAttrFn;
786790
if (elidedAttrs.empty()) {
787791
for (const NamedAttribute &attr : attrs)
788792
printAttribute(attr.getValue());
@@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
26872691
.Default([&](Type type) { return printDialectType(type); });
26882692
}
26892693

2690-
void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2691-
ArrayRef<StringRef> elidedAttrs,
2692-
bool withKeyword) {
2694+
void AsmPrinter::Impl::printOptionalAttrDict(
2695+
ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
2696+
bool withKeyword,
2697+
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
26932698
// If there are no attributes, then there is nothing to be done.
26942699
if (attrs.empty())
26952700
return;
@@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
27022707

27032708
// Otherwise, print them all out in braces.
27042709
os << " {";
2705-
interleaveComma(filteredAttrs,
2706-
[&](NamedAttribute attr) { printNamedAttribute(attr); });
2710+
interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
2711+
printNamedAttribute(attr, printNamedAttrFn);
2712+
});
27072713
os << '}';
27082714
};
27092715

@@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
27202726
if (!filteredAttrs.empty())
27212727
printFilteredAttributesFn(filteredAttrs);
27222728
}
2723-
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2729+
void AsmPrinter::Impl::printNamedAttribute(
2730+
NamedAttribute attr,
2731+
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
27242732
// Print the name without quotes if possible.
27252733
::printKeywordOrString(attr.getName().strref(), os);
27262734

@@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
27292737
return;
27302738

27312739
os << " = ";
2740+
if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
2741+
/// If we print via the `printNamedAttrFn` callback skip printing.
2742+
return;
2743+
}
27322744
printAttribute(attr.getValue());
27332745
}
27342746

@@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
31493161

31503162
/// Print an optional attribute dictionary with a given set of elided values.
31513163
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
3152-
ArrayRef<StringRef> elidedAttrs = {}) override {
3153-
Impl::printOptionalAttrDict(attrs, elidedAttrs);
3164+
ArrayRef<StringRef> elidedAttrs = {},
3165+
function_ref<LogicalResult(NamedAttribute)>
3166+
printNamedAttrFn = nullptr) override {
3167+
Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
3168+
printNamedAttrFn);
31543169
}
31553170
void printOptionalAttrDictWithKeyword(
31563171
ArrayRef<NamedAttribute> attrs,

0 commit comments

Comments
 (0)