Skip to content

Commit 055e44d

Browse files
committed
[mlir] Add concept of alias blocks
This PR is part of https://discourse.llvm.org/t/rfc-supporting-aliases-in-cyclic-types-and-attributes/73236. It implements the concept of "alias blocks": a block of consecutive alias definitions in which any definition may reference any other alias definition within the block, regardless of definition order. This is purely a convenience for immutable attributes and types, but is a requirement for supporting alias definitions in cyclic mutable attributes and types. The implementation works by first parsing an alias block, which is simply a run of consecutive alias definitions, in a syntax-only mode. This syntax-only mode only checks the syntactic validity of the parsed attribute or type but does not verify any parsed data. This allows us to essentially skip over alias definitions for the purpose of first collecting them and associating every alias definition with its source region. In a second pass, we can start parsing the attributes and types while at the same time attempting to resolve any unknown alias references with our list of yet-to-be-parsed attributes and types, parsing them on demand if required. A later PR will hook up this mechanism to the `tryStartCyclicParse` method added in b121c26 to register cyclic attributes and types early, breaking the parsing cycles.
1 parent 1a65cd3 commit 055e44d

File tree

9 files changed

+370
-86
lines changed

9 files changed

+370
-86
lines changed

mlir/docs/LangRef.md

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,14 @@ starting with a `//` and going until the end of the line.
183183

184184
```
185185
// Top level production
186-
toplevel := (operation | attribute-alias-def | type-alias-def)*
186+
toplevel := (operation | alias-block-def)*
187+
alias-block-def := (attribute-alias-def | type-alias-def)+
187188
```
188189

189190
The production `toplevel` is the top level production that is parsed by any parser
190-
consuming the MLIR syntax. [Operations](#operations),
191-
[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases)
191+
consuming the MLIR syntax. [Operations](#operations) and
192+
[Alias Blocks](#alias-block-definitions) consisting of
193+
[Attribute aliases](#attribute-value-aliases) and [Type aliases](#type-aliases)
192194
can be declared on the toplevel.
193195

194196
### Identifiers and keywords
@@ -880,3 +882,26 @@ version using readAttribute and readType methods.
880882
There is no restriction on what kind of information a dialect is allowed to
881883
encode to model its versioning. Currently, versioning is supported only for
882884
bytecode formats.
885+
886+
## Alias Block Definitions
887+
888+
An alias block is a list of consecutive attribute or type alias definitions that
889+
are conceptually parsed as one unit.
890+
This allows any alias definition within the block to reference any other alias
891+
definition within the block, regardless of whether it is defined lexically later or earlier in
892+
the block.
893+
894+
```mlir
895+
// Alias block consisting of #array, !integer_type and #integer_attr.
896+
#array = [#integer_attr, !integer_type]
897+
!integer_type = i32
898+
#integer_attr = 8 : !integer_type
899+
900+
// Illegal. !other_type is not part of this alias block and is only defined later
901+
// in the file.
902+
!tuple = tuple<i32, !other_type>
903+
904+
func.func @foo() { ... }
905+
906+
!other_type = f32
907+
```

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ Attribute Parser::parseAttribute(Type type) {
145145
parseLocationInstance(locAttr) ||
146146
parseToken(Token::r_paren, "expected ')' in inline location"))
147147
return Attribute();
148+
149+
if (syntaxOnly())
150+
return state.syntaxOnlyAttr;
151+
148152
return locAttr;
149153
}
150154

@@ -430,6 +434,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
430434
return FloatAttr::get(floatType, *result);
431435
}
432436

437+
if (syntaxOnly())
438+
return state.syntaxOnlyAttr;
439+
433440
if (!isa<IntegerType, IndexType>(type))
434441
return emitError(loc, "integer literal not valid for specified type"),
435442
nullptr;
@@ -1003,7 +1010,9 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
10031010
auto type = parseElementsLiteralType(attrType);
10041011
if (!type)
10051012
return nullptr;
1006-
return literalParser.getAttr(loc, type);
1013+
if (syntaxOnly())
1014+
return state.syntaxOnlyAttr;
1015+
return literalParser.getAttr(loc, cast<ShapedType>(type));
10071016
}
10081017

10091018
Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
@@ -1030,6 +1039,9 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
10301039
return nullptr;
10311040
}
10321041

1042+
if (syntaxOnly())
1043+
return state.syntaxOnlyAttr;
1044+
10331045
ShapedType shapedType = dyn_cast<ShapedType>(attrType);
10341046
if (!shapedType) {
10351047
emitError(typeLoc, "`dense_resource` expected a shaped type");
@@ -1044,7 +1056,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
10441056
/// elements-literal-type ::= vector-type | ranked-tensor-type
10451057
///
10461058
/// This method also checks the type has static shape.
1047-
ShapedType Parser::parseElementsLiteralType(Type type) {
1059+
Type Parser::parseElementsLiteralType(Type type) {
10481060
// If the user didn't provide a type, parse the colon type for the literal.
10491061
if (!type) {
10501062
if (parseToken(Token::colon, "expected ':'"))
@@ -1053,6 +1065,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
10531065
return nullptr;
10541066
}
10551067

1068+
if (syntaxOnly())
1069+
return state.syntaxOnlyType;
1070+
10561071
auto sType = dyn_cast<ShapedType>(type);
10571072
if (!sType) {
10581073
emitError("elements literal must be a shaped type");
@@ -1077,17 +1092,23 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
10771092
// of the type.
10781093
Type indiceEltType = builder.getIntegerType(64);
10791094
if (consumeIf(Token::greater)) {
1080-
ShapedType type = parseElementsLiteralType(attrType);
1095+
Type type = parseElementsLiteralType(attrType);
10811096
if (!type)
10821097
return nullptr;
10831098

1099+
if (syntaxOnly())
1100+
return state.syntaxOnlyAttr;
1101+
10841102
// Construct the sparse elements attr using zero element indice/value
10851103
// attributes.
1104+
ShapedType shapedType = cast<ShapedType>(type);
10861105
ShapedType indicesType =
1087-
RankedTensorType::get({0, type.getRank()}, indiceEltType);
1088-
ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1106+
RankedTensorType::get({0, shapedType.getRank()}, indiceEltType);
1107+
ShapedType valuesType =
1108+
RankedTensorType::get({0}, shapedType.getElementType());
10891109
return getChecked<SparseElementsAttr>(
1090-
loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1110+
loc, shapedType,
1111+
DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
10911112
DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
10921113
}
10931114

@@ -1114,14 +1135,20 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
11141135
if (!type)
11151136
return nullptr;
11161137

1138+
if (syntaxOnly())
1139+
return state.syntaxOnlyAttr;
1140+
1141+
ShapedType shapedType = cast<ShapedType>(type);
1142+
11171143
// If the indices are a splat, i.e. the literal parser parsed an element and
11181144
// not a list, we set the shape explicitly. The indices are represented by a
11191145
// 2-dimensional shape where the second dimension is the rank of the type.
11201146
// Given that the parsed indices is a splat, we know that we only have one
11211147
// indice and thus one for the first dimension.
11221148
ShapedType indicesType;
11231149
if (indiceParser.getShape().empty()) {
1124-
indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1150+
indicesType =
1151+
RankedTensorType::get({1, shapedType.getRank()}, indiceEltType);
11251152
} else {
11261153
// Otherwise, set the shape to the one parsed by the literal parser.
11271154
indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
@@ -1131,15 +1158,15 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
11311158
// If the values are a splat, set the shape explicitly based on the number of
11321159
// indices. The number of indices is encoded in the first dimension of the
11331160
// indice shape type.
1134-
auto valuesEltType = type.getElementType();
1161+
auto valuesEltType = shapedType.getElementType();
11351162
ShapedType valuesType =
11361163
valuesParser.getShape().empty()
11371164
? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
11381165
: RankedTensorType::get(valuesParser.getShape(), valuesEltType);
11391166
auto values = valuesParser.getAttr(valuesLoc, valuesType);
11401167

11411168
// Build the sparse elements attribute by the indices and values.
1142-
return getChecked<SparseElementsAttr>(loc, type, indices, values);
1169+
return getChecked<SparseElementsAttr>(loc, shapedType, indices, values);
11431170
}
11441171

11451172
Attribute Parser::parseStridedLayoutAttr() {
@@ -1260,6 +1287,9 @@ Attribute Parser::parseDistinctAttr(Type type) {
12601287
return {};
12611288
}
12621289

1290+
if (syntaxOnly())
1291+
return state.syntaxOnlyAttr;
1292+
12631293
// Add the distinct attribute to the parser state, if it has not been parsed
12641294
// before. Otherwise, check if the parsed reference attribute matches the one
12651295
// found in the parser state.

mlir/lib/AsmParser/DialectSymbolParser.cpp

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/BuiltinTypes.h"
1616
#include "mlir/IR/Dialect.h"
1717
#include "mlir/IR/DialectImplementation.h"
18+
#include "llvm/ADT/ScopeExit.h"
1819
#include "llvm/Support/SourceMgr.h"
1920

2021
using namespace mlir;
@@ -156,9 +157,11 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
156157
}
157158

158159
/// Parse an extended dialect symbol.
159-
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
160+
template <typename Symbol, typename SymbolAliasMap, typename ParseAliasFn,
161+
typename CreateFn>
160162
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
161163
SymbolAliasMap &aliases,
164+
ParseAliasFn &parseAliasFn,
162165
CreateFn &&createSymbol) {
163166
Token tok = p.getToken();
164167

@@ -185,12 +188,32 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
185188
// If there is no '<' token following this, and if the typename contains no
186189
// dot, then we are parsing a symbol alias.
187190
if (!hasTrailingData && !isPrettyName) {
191+
192+
// Don't check the validity of alias reference in syntax-only mode.
193+
if (p.syntaxOnly()) {
194+
if constexpr (std::is_same_v<Symbol, Type>)
195+
return p.getState().syntaxOnlyType;
196+
else
197+
return p.getState().syntaxOnlyAttr;
198+
}
199+
188200
// Check for an alias for this type.
189201
auto aliasIt = aliases.find(identifier);
190-
if (aliasIt == aliases.end())
191-
return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
192-
"'"),
193-
nullptr);
202+
if (aliasIt == aliases.end()) {
203+
FailureOr<Symbol> symbol = failure();
204+
// Try the parse alias function if set.
205+
if (parseAliasFn)
206+
symbol = parseAliasFn(identifier);
207+
208+
if (failed(symbol)) {
209+
p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'");
210+
return nullptr;
211+
}
212+
if (!*symbol)
213+
return nullptr;
214+
215+
aliasIt = aliases.insert({identifier, *symbol}).first;
216+
}
194217
if (asmState) {
195218
if constexpr (std::is_same_v<Symbol, Type>)
196219
asmState->addTypeAliasUses(identifier, range);
@@ -241,12 +264,16 @@ Attribute Parser::parseExtendedAttr(Type type) {
241264
MLIRContext *ctx = getContext();
242265
Attribute attr = parseExtendedSymbol<Attribute>(
243266
*this, state.asmState, state.symbols.attributeAliasDefinitions,
267+
state.symbols.parseUnknownAttributeAlias,
244268
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
245269
// Parse an optional trailing colon type.
246270
Type attrType = type;
247271
if (consumeIf(Token::colon) && !(attrType = parseType()))
248272
return Attribute();
249273

274+
if (syntaxOnly())
275+
return state.syntaxOnlyAttr;
276+
250277
// If we found a registered dialect, then ask it to parse the attribute.
251278
if (Dialect *dialect =
252279
builder.getContext()->getOrLoadDialect(dialectName)) {
@@ -288,7 +315,11 @@ Type Parser::parseExtendedType() {
288315
MLIRContext *ctx = getContext();
289316
return parseExtendedSymbol<Type>(
290317
*this, state.asmState, state.symbols.typeAliasDefinitions,
318+
state.symbols.parseUnknownTypeAlias,
291319
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
320+
if (syntaxOnly())
321+
return state.syntaxOnlyType;
322+
292323
// If we found a registered dialect, then ask it to parse the type.
293324
if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
294325
// Temporarily reset the lexer to let the dialect parse the type.

mlir/lib/AsmParser/LocationParser.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
5353
if (parseToken(Token::r_paren, "expected ')' in callsite location"))
5454
return failure();
5555

56+
if (syntaxOnly())
57+
return success();
58+
5659
// Return the callsite location.
5760
loc = CallSiteLoc::get(calleeLoc, callerLoc);
5861
return success();
@@ -79,6 +82,9 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
7982
LocationAttr newLoc;
8083
if (parseLocationInstance(newLoc))
8184
return failure();
85+
if (syntaxOnly())
86+
return success();
87+
8288
locations.push_back(newLoc);
8389
return success();
8490
};
@@ -135,12 +141,15 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
135141
if (parseLocationInstance(childLoc))
136142
return failure();
137143

138-
loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);
139-
140144
// Parse the closing ')'.
141145
if (parseToken(Token::r_paren,
142146
"expected ')' after child location of NameLoc"))
143147
return failure();
148+
149+
if (syntaxOnly())
150+
return success();
151+
152+
loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);
144153
} else {
145154
loc = NameLoc::get(StringAttr::get(ctx, str));
146155
}
@@ -154,6 +163,10 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
154163
Attribute locAttr = parseExtendedAttr(Type());
155164
if (!locAttr)
156165
return failure();
166+
167+
if (syntaxOnly())
168+
return success();
169+
157170
if (!(loc = dyn_cast<LocationAttr>(locAttr)))
158171
return emitError("expected location attribute, but got") << locAttr;
159172
return success();

0 commit comments

Comments
 (0)