Skip to content

[mlir] Add the ability to override attribute parsing/printing in attr-dicts #103304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
/// printed some other way (like as a fixed operand).
/// printed some other way (like as a fixed operand). If printNamedAttrFn is
/// provided the default printing can be overridden for a named attribute.
/// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
/// it returns `success()`, otherwise, it returns `failure()` which indicates
/// that generic printing should be used.
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) = 0;
ArrayRef<StringRef> elidedAttrs = {},
function_ref<LogicalResult(NamedAttribute)>
printNamedAttrFn = nullptr) = 0;

/// If the specified operation has attributes, print out an attribute
/// dictionary prefixed with 'attributes'.
Expand Down Expand Up @@ -1116,8 +1122,17 @@ class AsmParser {
return parseResult;
}

/// Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
/// Parse a named dictionary into 'result' if it is present. If
/// parseNamedAttrFn is provided the default parsing can be overridden for a
/// named attribute. parseNamedAttrFn is passed the name of an attribute, if
/// it can parse the attribute it returns the parsed attribute, otherwise, it
/// returns `failure()` which indicates that generic parsing should be used.
/// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
/// error.
virtual ParseResult parseOptionalAttrDict(
NamedAttrList &result,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
nullptr) = 0;

/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
}

/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
ParseResult parseOptionalAttrDict(
NamedAttrList &result,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
nullptr) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
return parser.parseAttributeDict(result);
return parser.parseAttributeDict(result, parseNamedAttrFn);
}

/// Parse a named dictionary into 'result' if the `attributes` keyword is
Expand Down
16 changes: 14 additions & 2 deletions mlir/lib/AsmParser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
/// | `{` attribute-entry (`,` attribute-entry)* `}`
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
ParseResult Parser::parseAttributeDict(
NamedAttrList &attributes,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
llvm::SmallDenseSet<StringAttr> seenKeys;
auto parseElt = [&]() -> ParseResult {
// The name of an attribute can either be a bare identifier, or a string.
Expand Down Expand Up @@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
}

auto attr = parseAttribute();
Attribute attr = nullptr;
FailureOr<Attribute> customParsedAttribute;
// Try to parse with `printNamedAttrFn` callback.
if (parseNamedAttrFn &&
succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
attr = *customParsedAttribute;
} else {
// Otherwise, use generic attribute parser.
attr = parseAttribute();
}

if (!attr)
return failure();
attributes.push_back({*nameId, attr});
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/AsmParser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ class Parser {
}

/// Parse an attribute dictionary.
ParseResult parseAttributeDict(NamedAttrList &attributes);
ParseResult parseAttributeDict(
NamedAttrList &attributes,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);

/// Parse a distinct attribute.
Attribute parseDistinctAttr(Type type);
Expand Down
46 changes: 31 additions & 15 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,13 @@ class AsmPrinter::Impl {
void printDimensionList(ArrayRef<int64_t> shape);

protected:
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {},
bool withKeyword = false);
void printNamedAttribute(NamedAttribute attr);
void printOptionalAttrDict(
ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
bool withKeyword = false,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printNamedAttribute(
NamedAttribute attr,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printTrailingLocation(Location loc, bool allowAlias = true);
void printLocationInternal(LocationAttr loc, bool pretty = false,
bool isTopLevel = false);
Expand Down Expand Up @@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
if (attrs.empty())
return;
ArrayRef<StringRef> elidedAttrs = {},
function_ref<LogicalResult(NamedAttribute)>
printNamedAttrFn = nullptr) override {
(void)printNamedAttrFn;
if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs)
printAttribute(attr.getValue());
Expand Down Expand Up @@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}

void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs,
bool withKeyword) {
void AsmPrinter::Impl::printOptionalAttrDict(
ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
bool withKeyword,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
Expand All @@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,

// Otherwise, print them all out in braces.
os << " {";
interleaveComma(filteredAttrs,
[&](NamedAttribute attr) { printNamedAttribute(attr); });
interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
printNamedAttribute(attr, printNamedAttrFn);
});
os << '}';
};

Expand All @@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
void AsmPrinter::Impl::printNamedAttribute(
NamedAttribute attr,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);

Expand All @@ -2729,6 +2737,11 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
return;

os << " = ";
if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
/// If we print via the `printNamedAttrFn` callback, skip the generic
/// attribute printing (i.e. the call to `printAttribute`).
return;
}
printAttribute(attr.getValue());
}

Expand Down Expand Up @@ -3149,8 +3162,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {

/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
Impl::printOptionalAttrDict(attrs, elidedAttrs);
ArrayRef<StringRef> elidedAttrs = {},
function_ref<LogicalResult(NamedAttribute)>
printNamedAttrFn = nullptr) override {
Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
printNamedAttrFn);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: mlir-opt %s | FileCheck %s --check-prefix=CHECK-ROUNDTRIP
// RUN: mlir-opt %s -mlir-print-op-generic | FileCheck %s --check-prefix=CHECK-GENERIC-SYNTAX

/// This file tetss that "custom_dense_array" (which is a DenseArrayAttribute
/// stored within the attr-dict) is parsed and printed with the "pretty" array
/// syntax (i.e. `[1, 2, 3, 4]`), rather than with the generic dense array
/// syntax (`array<i64: 1, 2, 3, 4>`).
///
/// This is done by injecting custom parsing and printing callbacks into
/// parseOptionalAttrDict() and printOptionalAttrDict().

func.func @custom_attr_dict_syntax() {
// CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]}
// CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> : () -> ()
test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]}

// CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]}
// CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {another_attr = "foo"} : () -> ()
test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]}

// CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]}
// CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {default_array = [1, 2, 3, 4]} : () -> ()
test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]}

// CHECK-ROUND-TRIP: test.custom_attr_parse_and_print_in_attr_dict {default_dense_array = array<i64: 1, 2, 3, 4>, custom_dense_array = [1, 2, 3, 4]}
// CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array<i64: 1, 2, 3, 4>}> {default_dense_array = array<i64: 1, 2, 3, 4>} : () -> ()
test.custom_attr_parse_and_print_in_attr_dict {default_dense_array = array<i64: 1, 2, 3, 4>, custom_dense_array = [1, 2, 3, 4]}

return
}
31 changes: 31 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOpDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,37 @@ void AffineScopeOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// CustomAttrParseAndPrintInAttrDict
//===----------------------------------------------------------------------===//

ParseResult CustomAttrParseAndPrintInAttrDict::parse(OpAsmParser &parser,
OperationState &result) {
return parser.parseOptionalAttrDict(
result.attributes, [&](StringRef name) -> FailureOr<Attribute> {
// Override the parsing for the "custom_dense_array" attribute in the
// attr-dict. Rather than parsing it as array<i64: 0, 1, 2, ...>, parse
// it as [0, 1, 2, ...] (i.e. using the standard array syntax).
if (name != getCustomDenseArrayAttrName(result.name))
return failure();
return DenseI64ArrayAttr::parse(parser, {});
});
}

void CustomAttrParseAndPrintInAttrDict::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(
(*this)->getAttrs(), {},
[&](NamedAttribute attrDictNamedAttribute) -> LogicalResult {
// Override the printing for the "custom_dense_array" attribute. Rather
// than printing it as array<i64: 0, 1, 2, ...>, print it as
// [0, 1, 2, ...] (i.e. using standard array syntax).
if (attrDictNamedAttribute.getName() != getCustomDenseArrayAttrName())
return failure();
cast<DenseI64ArrayAttr>(attrDictNamedAttribute.getValue()).print(p);
return success();
});
}

//===----------------------------------------------------------------------===//
// TestRemoveOpWithInnerOps
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2083,6 +2083,16 @@ def OptionalCustomAttrOp : TEST_Op<"optional_custom_attr"> {
}];
}

//===----------------------------------------------------------------------===//
// Test overriding attribute parsing/printing in the attr-dict via callbacks
// on parseOptionalAttrDict() and printOptionalAttrDict().

def CustomAttrParseAndPrintInAttrDict : TEST_Op<"custom_attr_parse_and_print_in_attr_dict">
{
let arguments = (ins DenseI64ArrayAttr:$custom_dense_array);
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Test OpAsmInterface.

Expand Down
Loading