Skip to content

[mlir] Add concept of alias blocks #65503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions mlir/docs/LangRef.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,14 @@ starting with a `//` and going until the end of the line.

```
// Top level production
toplevel := (operation | attribute-alias-def | type-alias-def)*
toplevel := (operation | alias-block-def)*
alias-block-def := (attribute-alias-def | type-alias-def)*
```

The production `toplevel` is the top level production that is parsed by any parsing
consuming the MLIR syntax. [Operations](#operations),
[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases)
consuming the MLIR syntax. [Operations](#operations) and
[Alias Blocks](#alias-block-definitions) consisting of
[Attribute aliases](#attribute-value-aliases) and [Type aliases](#type-aliases)
can be declared on the toplevel.

### Identifiers and keywords
Expand Down Expand Up @@ -880,3 +882,26 @@ version using readAttribute and readType methods.
There is no restriction on what kind of information a dialect is allowed to
encode to model its versioning. Currently, versioning is supported only for
bytecode formats.

## Alias Block Definitions

An alias block is a list of subsequent attribute or type alias definitions that
are conceptually parsed as one unit.
This allows any alias definition within the block to reference any other alias
definition within the block, regardless if defined lexically later or earlier in
the block.

```mlir
// Alias block consisting of #array, !integer_type and #integer_attr.
#array = [#integer_attr, !integer_type]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If understand this correctly, it's basically a forward declaration?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. It's essentially a use-before-def. Aliases within a block are allowed to reference other aliases within the block prior to their definition.
So in this case #integer_attr and !integer_type are declared afterwards, not prior.

If it were a forward declaration then I'd think it'd look something like:

// "forward" declaration
#array = []
!integer_type = i32
#integer_attr = 8 : !integer_type
#array = [#integer_attr, !integer_type]

this would in my opinion not be very intuative and for the larger issue I am trying to solve, that is using aliases with cyclic attributes and types, less ergonomic:

!type = !llvm.struct<"test">
!body = !llvm.ptr<!type>
!type = !llvm.struct<"test", (!body)>

instead of just

!type = !llvm.struct<"test", !body>
!body = !llvm.ptr<!type>

(llvm dialect only used for illustrative purposes)

This would also require an API/Interface for mutable attributes and types to return a "forward" declaration

!integer_type = i32
#integer_attr = 8 : !integer_type

// Illegal. !other_type is not part of this alias block and defined later
// in the file.
!tuple = tuple<i32, !other_type>

func.func @foo() { ... }

!other_type = f32
```
48 changes: 39 additions & 9 deletions mlir/lib/AsmParser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ Attribute Parser::parseAttribute(Type type) {
parseLocationInstance(locAttr) ||
parseToken(Token::r_paren, "expected ')' in inline location"))
return Attribute();

if (syntaxOnly())
return state.syntaxOnlyAttr;

return locAttr;
}

Expand Down Expand Up @@ -430,6 +434,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return FloatAttr::get(floatType, *result);
}

if (syntaxOnly())
return state.syntaxOnlyAttr;

if (!isa<IntegerType, IndexType>(type))
return emitError(loc, "integer literal not valid for specified type"),
nullptr;
Expand Down Expand Up @@ -1003,7 +1010,9 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
auto type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;
return literalParser.getAttr(loc, type);
if (syntaxOnly())
return state.syntaxOnlyAttr;
return literalParser.getAttr(loc, cast<ShapedType>(type));
}

Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
Expand All @@ -1030,6 +1039,9 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
return nullptr;
}

if (syntaxOnly())
return state.syntaxOnlyAttr;

ShapedType shapedType = dyn_cast<ShapedType>(attrType);
if (!shapedType) {
emitError(typeLoc, "`dense_resource` expected a shaped type");
Expand All @@ -1044,7 +1056,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
/// elements-literal-type ::= vector-type | ranked-tensor-type
///
/// This method also checks the type has static shape.
ShapedType Parser::parseElementsLiteralType(Type type) {
Type Parser::parseElementsLiteralType(Type type) {
// If the user didn't provide a type, parse the colon type for the literal.
if (!type) {
if (parseToken(Token::colon, "expected ':'"))
Expand All @@ -1053,6 +1065,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
return nullptr;
}

if (syntaxOnly())
return state.syntaxOnlyType;

auto sType = dyn_cast<ShapedType>(type);
if (!sType) {
emitError("elements literal must be a shaped type");
Expand All @@ -1077,17 +1092,23 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// of the type.
Type indiceEltType = builder.getIntegerType(64);
if (consumeIf(Token::greater)) {
ShapedType type = parseElementsLiteralType(attrType);
Type type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;

if (syntaxOnly())
return state.syntaxOnlyAttr;

// Construct the sparse elements attr using zero element indice/value
// attributes.
ShapedType shapedType = cast<ShapedType>(type);
ShapedType indicesType =
RankedTensorType::get({0, type.getRank()}, indiceEltType);
ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
RankedTensorType::get({0, shapedType.getRank()}, indiceEltType);
ShapedType valuesType =
RankedTensorType::get({0}, shapedType.getElementType());
return getChecked<SparseElementsAttr>(
loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
loc, shapedType,
DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
}

Expand All @@ -1114,14 +1135,20 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
if (!type)
return nullptr;

if (syntaxOnly())
return state.syntaxOnlyAttr;

ShapedType shapedType = cast<ShapedType>(type);

// If the indices are a splat, i.e. the literal parser parsed an element and
// not a list, we set the shape explicitly. The indices are represented by a
// 2-dimensional shape where the second dimension is the rank of the type.
// Given that the parsed indices is a splat, we know that we only have one
// indice and thus one for the first dimension.
ShapedType indicesType;
if (indiceParser.getShape().empty()) {
indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
indicesType =
RankedTensorType::get({1, shapedType.getRank()}, indiceEltType);
} else {
// Otherwise, set the shape to the one parsed by the literal parser.
indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
Expand All @@ -1131,15 +1158,15 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// If the values are a splat, set the shape explicitly based on the number of
// indices. The number of indices is encoded in the first dimension of the
// indice shape type.
auto valuesEltType = type.getElementType();
auto valuesEltType = shapedType.getElementType();
ShapedType valuesType =
valuesParser.getShape().empty()
? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
: RankedTensorType::get(valuesParser.getShape(), valuesEltType);
auto values = valuesParser.getAttr(valuesLoc, valuesType);

// Build the sparse elements attribute by the indices and values.
return getChecked<SparseElementsAttr>(loc, type, indices, values);
return getChecked<SparseElementsAttr>(loc, shapedType, indices, values);
}

Attribute Parser::parseStridedLayoutAttr() {
Expand Down Expand Up @@ -1260,6 +1287,9 @@ Attribute Parser::parseDistinctAttr(Type type) {
return {};
}

if (syntaxOnly())
return state.syntaxOnlyAttr;

// Add the distinct attribute to the parser state, if it has not been parsed
// before. Otherwise, check if the parsed reference attribute matches the one
// found in the parser state.
Expand Down
41 changes: 36 additions & 5 deletions mlir/lib/AsmParser/DialectSymbolParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/ScopeExit.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This seems to be unused

#include "llvm/Support/SourceMgr.h"

using namespace mlir;
Expand Down Expand Up @@ -156,9 +157,11 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
}

/// Parse an extended dialect symbol.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
template <typename Symbol, typename SymbolAliasMap, typename ParseAliasFn,
typename CreateFn>
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
SymbolAliasMap &aliases,
ParseAliasFn &parseAliasFn,
CreateFn &&createSymbol) {
Token tok = p.getToken();

Expand All @@ -185,12 +188,32 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
// If there is no '<' token following this, and if the typename contains no
// dot, then we are parsing a symbol alias.
if (!hasTrailingData && !isPrettyName) {

// Don't check the validity of alias reference in syntax-only mode.
if (p.syntaxOnly()) {
if constexpr (std::is_same_v<Symbol, Type>)
return p.getState().syntaxOnlyType;
else
return p.getState().syntaxOnlyAttr;
}

// Check for an alias for this type.
auto aliasIt = aliases.find(identifier);
if (aliasIt == aliases.end())
return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
"'"),
nullptr);
if (aliasIt == aliases.end()) {
FailureOr<Symbol> symbol = failure();
// Try the parse alias function if set.
if (parseAliasFn)
symbol = parseAliasFn(identifier);

if (failed(symbol)) {
p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'");
return nullptr;
}
if (!*symbol)
return nullptr;

aliasIt = aliases.insert({identifier, *symbol}).first;
}
if (asmState) {
if constexpr (std::is_same_v<Symbol, Type>)
asmState->addTypeAliasUses(identifier, range);
Expand Down Expand Up @@ -241,12 +264,16 @@ Attribute Parser::parseExtendedAttr(Type type) {
MLIRContext *ctx = getContext();
Attribute attr = parseExtendedSymbol<Attribute>(
*this, state.asmState, state.symbols.attributeAliasDefinitions,
state.symbols.parseUnknownAttributeAlias,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
// Parse an optional trailing colon type.
Type attrType = type;
if (consumeIf(Token::colon) && !(attrType = parseType()))
return Attribute();

if (syntaxOnly())
return state.syntaxOnlyAttr;

// If we found a registered dialect, then ask it to parse the attribute.
if (Dialect *dialect =
builder.getContext()->getOrLoadDialect(dialectName)) {
Expand Down Expand Up @@ -288,7 +315,11 @@ Type Parser::parseExtendedType() {
MLIRContext *ctx = getContext();
return parseExtendedSymbol<Type>(
*this, state.asmState, state.symbols.typeAliasDefinitions,
state.symbols.parseUnknownTypeAlias,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
if (syntaxOnly())
return state.syntaxOnlyType;

// If we found a registered dialect, then ask it to parse the type.
if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
// Temporarily reset the lexer to let the dialect parse the type.
Expand Down
17 changes: 15 additions & 2 deletions mlir/lib/AsmParser/LocationParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
if (parseToken(Token::r_paren, "expected ')' in callsite location"))
return failure();

if (syntaxOnly())
return success();

// Return the callsite location.
loc = CallSiteLoc::get(calleeLoc, callerLoc);
return success();
Expand All @@ -79,6 +82,9 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
LocationAttr newLoc;
if (parseLocationInstance(newLoc))
return failure();
if (syntaxOnly())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems very invasive. There is prior art to allow use-before-def, such as deferred aliases. Why won't that work for your use-case?

Copy link
Member Author

@zero9178 zero9178 Sep 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ruled out the deferred aliases approach as not implementable.
The way it works for locations is that a placeholder opaque location is created that is then later RAUW with the actual location at the end of the module.

The reason I don't think is implementable is because todays attributes and types would not be able to handle having some placeholder returned from parseType or parseAttribute as they tend to check that these are of specific kinds (e.g. IntegerAttr or MemRefElementTypeInterface or fullfill some other specific constraints.
See also https://discourse.llvm.org/t/rfc-supporting-aliases-in-cyclic-types-and-attributes/73236#return-placeholder-attributes-or-types-10
This is not an issue for locations as they are just attached to Operations and do not need to fullfill any specific constraints

I agree that this is invasive and I'd love to hear better options. I initially considered writing separate skipType and skipAttr functions, but there I'd essentially just end up reimplementing the whole parser for the builtin attributes, including error messages, which would lead to a lot of code duplications. Any other way of skipping that isn't based on actually parsing the syntactic elements I'd consider more of a "hack" that is prone to breaking.

The silver lining is that this logic and invasiveness is private to the parser implementation and not user exposed and only part of the builtin attribute and type parsing, not any dialect parsing.

return success();

locations.push_back(newLoc);
return success();
};
Expand Down Expand Up @@ -135,12 +141,15 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
if (parseLocationInstance(childLoc))
return failure();

loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);

// Parse the closing ')'.
if (parseToken(Token::r_paren,
"expected ')' after child location of NameLoc"))
return failure();

if (syntaxOnly())
return success();

loc = NameLoc::get(StringAttr::get(ctx, str), childLoc);
} else {
loc = NameLoc::get(StringAttr::get(ctx, str));
}
Expand All @@ -154,6 +163,10 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
Attribute locAttr = parseExtendedAttr(Type());
if (!locAttr)
return failure();

if (syntaxOnly())
return success();

if (!(loc = dyn_cast<LocationAttr>(locAttr)))
return emitError("expected location attribute, but got") << locAttr;
return success();
Expand Down
Loading