Skip to content

Commit 1b2c16f

Browse files
committed
[mlir][DeclarativeParser] Add support for attributes with buildable types.
This revision adds support in the declarative assembly form for printing attributes with buildable types without the type, and moves several more parsers over to the declarative form. Differential Revision: https://reviews.llvm.org/D74276
1 parent 327e062 commit 1b2c16f

File tree

8 files changed

+112
-42
lines changed

8 files changed

+112
-42
lines changed

mlir/docs/OpDefinitions.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,9 @@ A variable is an entity that has been registered on the operation itself, i.e.
616616
an argument(attribute or operand), result, etc. In the `CallOp` example above,
617617
the variables would be `$callee` and `$args`.
618618

619+
Attribute variables are printed with their respective value type, unless that
620+
value type is buildable. In those cases, the type of the attribute is elided.
621+
619622
#### Requirements
620623

621624
The format specification has a certain set of requirements that must be adhered

mlir/include/mlir/IR/OpBase.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,10 @@ class Attr<Pred condition, string descr = ""> :
629629
// Requires a constBuilderCall defined.
630630
string defaultValue = ?;
631631

632+
// The value type of this attribute. This corresponds to the mlir::Type that
633+
// this attribute returns via `getType()`.
634+
Type valueType = ?;
635+
632636
// Whether the attribute is optional. Typically requires a custom
633637
// convertFromStorage method to handle the case where the attribute is
634638
// not present.
@@ -660,6 +664,7 @@ class DefaultValuedAttr<Attr attr, string val> :
660664
let convertFromStorage = attr.convertFromStorage;
661665
let constBuilderCall = attr.constBuilderCall;
662666
let defaultValue = val;
667+
let valueType = attr.valueType;
663668

664669
let baseAttr = attr;
665670
}
@@ -673,6 +678,7 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
673678
let returnType = "Optional<" # attr.returnType #">";
674679
let convertFromStorage = "$_self ? " # returnType # "(" #
675680
attr.convertFromStorage # ") : (llvm::None)";
681+
let valueType = attr.valueType;
676682
let isOptional = 1;
677683

678684
let baseAttr = attr;
@@ -681,14 +687,15 @@ class OptionalAttr<Attr attr> : Attr<attr.predicate, attr.description> {
681687
//===----------------------------------------------------------------------===//
682688
// Primitive attribute kinds
683689

684-
// A generic attribute that must be constructed around a specific type
690+
// A generic attribute that must be constructed around a specific buildable type
685691
// `attrValType`. Backed by MLIR attribute kind `attrKind`.
686-
class TypedAttrBase<BuildableType attrValType, string attrKind,
687-
Pred condition, string descr> :
692+
class TypedAttrBase<Type attrValType, string attrKind, Pred condition,
693+
string descr> :
688694
Attr<condition, descr> {
689695
let constBuilderCall = "$_builder.get" # attrKind # "(" #
690696
attrValType.builderCall # ", $0)";
691697
let storageType = attrKind;
698+
let valueType = attrValType;
692699
}
693700

694701
// Any attribute.
@@ -1227,6 +1234,7 @@ class Confined<Attr attr, list<AttrConstraint> constraints> : Attr<
12271234
let convertFromStorage = attr.convertFromStorage;
12281235
let constBuilderCall = attr.constBuilderCall;
12291236
let defaultValue = attr.defaultValue;
1237+
let valueType = attr.valueType;
12301238
let isOptional = attr.isOptional;
12311239

12321240
let baseAttr = attr;

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ class OpAsmPrinter {
5858
virtual void printType(Type type) = 0;
5959
virtual void printAttribute(Attribute attr) = 0;
6060

61+
/// Print the given attribute without its type. The corresponding parser must
62+
/// provide a valid type for the attribute.
63+
virtual void printAttributeWithoutType(Attribute attr) = 0;
64+
6165
/// Print a successor, and use list, of a terminator operation given the
6266
/// terminator and the successor index.
6367
virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0;

mlir/include/mlir/TableGen/Attribute.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Record;
2525

2626
namespace mlir {
2727
namespace tblgen {
28+
class Type;
2829

2930
// Wrapper class with helper methods for accessing attribute constraints defined
3031
// in TableGen.
@@ -54,6 +55,10 @@ class Attribute : public AttrConstraint {
5455
// Returns the return type for this attribute.
5556
StringRef getReturnType() const;
5657

58+
// Return the type constraint corresponding to the type of this attribute, or
59+
// None if this is not a TypedAttr.
60+
llvm::Optional<Type> getValueType() const;
61+
5762
// Returns the template getter method call which reads this attribute's
5863
// storage and returns the value as of the desired return type.
5964
// The call will contain a `{0}` which will be expanded to this attribute.

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -855,10 +855,21 @@ class ModulePrinter {
855855
mlir::interleaveComma(c, os, each_fn);
856856
}
857857

858-
/// Print the given attribute. If 'mayElideType' is true, some attributes are
859-
/// printed without the type when the type matches the default used in the
860-
/// parser (for example i64 is the default for integer attributes).
861-
void printAttribute(Attribute attr, bool mayElideType = false);
858+
/// This enum descripes the different kinds of elision for the type of an
859+
/// attribute when printing it.
860+
enum class AttrTypeElision {
861+
/// The type must not be elided,
862+
Never,
863+
/// The type may be elided when it matches the default used in the parser
864+
/// (for example i64 is the default for integer attributes).
865+
May,
866+
/// The type must be elided.
867+
Must
868+
};
869+
870+
/// Print the given attribute.
871+
void printAttribute(Attribute attr,
872+
AttrTypeElision typeElision = AttrTypeElision::Never);
862873

863874
void printType(Type type);
864875
void printLocation(LocationAttr loc);
@@ -1185,7 +1196,8 @@ static void printElidedElementsAttr(raw_ostream &os) {
11851196
os << R"(opaque<"", "0xDEADBEEF">)";
11861197
}
11871198

1188-
void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
1199+
void ModulePrinter::printAttribute(Attribute attr,
1200+
AttrTypeElision typeElision) {
11891201
if (!attr) {
11901202
os << "<<NULL ATTRIBUTE>>";
11911203
return;
@@ -1200,6 +1212,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
12001212
}
12011213
}
12021214

1215+
auto attrType = attr.getType();
12031216
switch (attr.getKind()) {
12041217
default:
12051218
return printDialectAttribute(attr);
@@ -1236,12 +1249,11 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
12361249
case StandardAttributes::Integer: {
12371250
auto intAttr = attr.cast<IntegerAttr>();
12381251
// Print all integer attributes as signed unless i1.
1239-
bool isSigned = intAttr.getType().isIndex() ||
1240-
intAttr.getType().getIntOrFloatBitWidth() != 1;
1252+
bool isSigned = attrType.isIndex() || attrType.getIntOrFloatBitWidth() != 1;
12411253
intAttr.getValue().print(os, isSigned);
12421254

12431255
// IntegerAttr elides the type if I64.
1244-
if (mayElideType && intAttr.getType().isInteger(64))
1256+
if (typeElision == AttrTypeElision::May && attrType.isInteger(64))
12451257
return;
12461258
break;
12471259
}
@@ -1250,7 +1262,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
12501262
printFloatValue(floatAttr.getValue(), os);
12511263

12521264
// FloatAttr elides the type if F64.
1253-
if (mayElideType && floatAttr.getType().isF64())
1265+
if (typeElision == AttrTypeElision::May && attrType.isF64())
12541266
return;
12551267
break;
12561268
}
@@ -1262,7 +1274,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
12621274
case StandardAttributes::Array:
12631275
os << '[';
12641276
interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
1265-
printAttribute(attr, /*mayElideType=*/true);
1277+
printAttribute(attr, AttrTypeElision::May);
12661278
});
12671279
os << ']';
12681280
break;
@@ -1339,9 +1351,8 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
13391351
break;
13401352
}
13411353

1342-
// Print the type if it isn't a 'none' type.
1343-
auto attrType = attr.getType();
1344-
if (!attrType.isa<NoneType>()) {
1354+
// Don't print the type if we must elide it, or if it is a None type.
1355+
if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
13451356
os << " : ";
13461357
printType(attrType);
13471358
}
@@ -1904,6 +1915,12 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
19041915
ModulePrinter::printAttribute(attr);
19051916
}
19061917

1918+
/// Print the given attribute without its type. The corresponding parser must
1919+
/// provide a valid type for the attribute.
1920+
void printAttributeWithoutType(Attribute attr) override {
1921+
ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
1922+
}
1923+
19071924
/// Print the ID for the given value.
19081925
void printOperand(Value value) override { printValueID(value); }
19091926

mlir/lib/Parser/Parser.cpp

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,14 @@ class Parser {
285285
Attribute parseDecOrHexAttr(Type type, bool isNegative);
286286

287287
/// Parse an opaque elements attribute.
288-
Attribute parseOpaqueElementsAttr();
288+
Attribute parseOpaqueElementsAttr(Type attrType);
289289

290290
/// Parse a dense elements attribute.
291-
Attribute parseDenseElementsAttr();
292-
ShapedType parseElementsLiteralType();
291+
Attribute parseDenseElementsAttr(Type attrType);
292+
ShapedType parseElementsLiteralType(Type type);
293293

294294
/// Parse a sparse elements attribute.
295-
Attribute parseSparseElementsAttr();
295+
Attribute parseSparseElementsAttr(Type attrType);
296296

297297
//===--------------------------------------------------------------------===//
298298
// Location Parsing
@@ -1505,7 +1505,7 @@ Attribute Parser::parseAttribute(Type type) {
15051505

15061506
// Parse a dense elements attribute.
15071507
case Token::kw_dense:
1508-
return parseDenseElementsAttr();
1508+
return parseDenseElementsAttr(type);
15091509

15101510
// Parse a dictionary attribute.
15111511
case Token::l_brace: {
@@ -1543,11 +1543,11 @@ Attribute Parser::parseAttribute(Type type) {
15431543

15441544
// Parse an opaque elements attribute.
15451545
case Token::kw_opaque:
1546-
return parseOpaqueElementsAttr();
1546+
return parseOpaqueElementsAttr(type);
15471547

15481548
// Parse a sparse elements attribute.
15491549
case Token::kw_sparse:
1550-
return parseSparseElementsAttr();
1550+
return parseSparseElementsAttr(type);
15511551

15521552
// Parse a string attribute.
15531553
case Token::string: {
@@ -1783,7 +1783,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
17831783
}
17841784

17851785
/// Parse an opaque elements attribute.
1786-
Attribute Parser::parseOpaqueElementsAttr() {
1786+
Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
17871787
consumeToken(Token::kw_opaque);
17881788
if (parseToken(Token::less, "expected '<' after 'opaque'"))
17891789
return nullptr;
@@ -1816,11 +1816,10 @@ Attribute Parser::parseOpaqueElementsAttr() {
18161816
return (emitError("opaque string only contains hex digits"), nullptr);
18171817

18181818
consumeToken(Token::string);
1819-
if (parseToken(Token::greater, "expected '>'") ||
1820-
parseToken(Token::colon, "expected ':'"))
1819+
if (parseToken(Token::greater, "expected '>'"))
18211820
return nullptr;
18221821

1823-
auto type = parseElementsLiteralType();
1822+
auto type = parseElementsLiteralType(attrType);
18241823
if (!type)
18251824
return nullptr;
18261825

@@ -2086,7 +2085,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
20862085
}
20872086

20882087
/// Parse a dense elements attribute.
2089-
Attribute Parser::parseDenseElementsAttr() {
2088+
Attribute Parser::parseDenseElementsAttr(Type attrType) {
20902089
consumeToken(Token::kw_dense);
20912090
if (parseToken(Token::less, "expected '<' after 'dense'"))
20922091
return nullptr;
@@ -2096,12 +2095,11 @@ Attribute Parser::parseDenseElementsAttr() {
20962095
if (literalParser.parse())
20972096
return nullptr;
20982097

2099-
if (parseToken(Token::greater, "expected '>'") ||
2100-
parseToken(Token::colon, "expected ':'"))
2098+
if (parseToken(Token::greater, "expected '>'"))
21012099
return nullptr;
21022100

21032101
auto typeLoc = getToken().getLoc();
2104-
auto type = parseElementsLiteralType();
2102+
auto type = parseElementsLiteralType(attrType);
21052103
if (!type)
21062104
return nullptr;
21072105
return literalParser.getAttr(typeLoc, type);
@@ -2112,10 +2110,14 @@ Attribute Parser::parseDenseElementsAttr() {
21122110
/// elements-literal-type ::= vector-type | ranked-tensor-type
21132111
///
21142112
/// This method also checks the type has static shape.
2115-
ShapedType Parser::parseElementsLiteralType() {
2116-
auto type = parseType();
2117-
if (!type)
2118-
return nullptr;
2113+
ShapedType Parser::parseElementsLiteralType(Type type) {
2114+
// If the user didn't provide a type, parse the colon type for the literal.
2115+
if (!type) {
2116+
if (parseToken(Token::colon, "expected ':'"))
2117+
return nullptr;
2118+
if (!(type = parseType()))
2119+
return nullptr;
2120+
}
21192121

21202122
if (!type.isa<RankedTensorType>() && !type.isa<VectorType>()) {
21212123
emitError("elements literal must be a ranked tensor or vector type");
@@ -2130,7 +2132,7 @@ ShapedType Parser::parseElementsLiteralType() {
21302132
}
21312133

21322134
/// Parse a sparse elements attribute.
2133-
Attribute Parser::parseSparseElementsAttr() {
2135+
Attribute Parser::parseSparseElementsAttr(Type attrType) {
21342136
consumeToken(Token::kw_sparse);
21352137
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
21362138
return nullptr;
@@ -2150,11 +2152,10 @@ Attribute Parser::parseSparseElementsAttr() {
21502152
if (valuesParser.parse())
21512153
return nullptr;
21522154

2153-
if (parseToken(Token::greater, "expected '>'") ||
2154-
parseToken(Token::colon, "expected ':'"))
2155+
if (parseToken(Token::greater, "expected '>'"))
21552156
return nullptr;
21562157

2157-
auto type = parseElementsLiteralType();
2158+
auto type = parseElementsLiteralType(attrType);
21582159
if (!type)
21592160
return nullptr;
21602161

mlir/lib/TableGen/Attribute.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ StringRef tblgen::Attribute::getReturnType() const {
7575
return getValueAsString(init);
7676
}
7777

78+
// Return the type constraint corresponding to the type of this attribute, or
79+
// None if this is not a TypedAttr.
80+
llvm::Optional<tblgen::Type> tblgen::Attribute::getValueType() const {
81+
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
82+
return tblgen::Type(defInit->getDef());
83+
return llvm::None;
84+
}
85+
7886
StringRef tblgen::Attribute::getConvertFromStorageCall() const {
7987
const auto *init = def->getValueInit("convertFromStorage");
8088
return getValueAsString(init);

0 commit comments

Comments
 (0)