Skip to content

[mlir] Add support for parsing and printing cyclic aliases #66663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions mlir/docs/LangRef.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,14 @@ starting with a `//` and going until the end of the line.

```
// Top level production
toplevel := (operation | attribute-alias-def | type-alias-def)*
toplevel := (operation | alias-block-def)*
alias-block-def := (attribute-alias-def | type-alias-def)+
```

The production `toplevel` is the top-level production that is parsed by any parser
consuming the MLIR syntax. [Operations](#operations),
[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases)
consuming the MLIR syntax. [Operations](#operations) and
[Alias Blocks](#alias-block-definitions) consisting of
[Attribute aliases](#attribute-value-aliases) and [Type aliases](#type-aliases)
can be declared on the toplevel.

### Identifiers and keywords
Expand Down Expand Up @@ -880,3 +882,26 @@ version using readAttribute and readType methods.
There is no restriction on what kind of information a dialect is allowed to
encode to model its versioning. Currently, versioning is supported only for
bytecode formats.

## Alias Block Definitions

An alias block is a sequence of consecutive attribute or type alias definitions
that is conceptually parsed as one unit.
This allows any alias definition within the block to reference any other alias
definition within the block, regardless of whether that definition appears
lexically earlier or later in the block.

```mlir
// Alias block consisting of #array, !integer_type and #integer_attr.
#array = [#integer_attr, !integer_type]
!integer_type = i32
#integer_attr = 8 : !integer_type

// Illegal. !other_type is not part of this alias block and is defined later
// in the file.
!tuple = tuple<i32, !other_type>

func.func @foo() { ... }

!other_type = f32
```
8 changes: 4 additions & 4 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ class AsmParser {
AttrOrTypeT> ||
std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>,
"Only mutable attributes or types can be cyclic");
if (failed(pushCyclicParsing(attrOrType.getAsOpaquePointer())))
if (failed(pushCyclicParsing(attrOrType)))
return failure();

return CyclicParseReset(this);
Expand All @@ -1377,11 +1377,11 @@ class AsmParser {
virtual FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) = 0;

/// Pushes a new attribute or type in the form of a type erased pointer
/// into an internal set.
/// Pushes a new attribute or type into an internal set.
/// Returns success if the type or attribute was inserted in the set or
/// failure if it was already contained.
virtual LogicalResult pushCyclicParsing(const void *opaquePointer) = 0;
virtual LogicalResult
pushCyclicParsing(PointerUnion<Attribute, Type> attrOrType) = 0;

/// Removes the element that was last inserted with a successful call to
/// `pushCyclicParsing`. There must be exactly one `popCyclicParsing` call
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,9 @@ class AsmParserImpl : public BaseT {
return parser.parseXInDimensionList();
}

LogicalResult pushCyclicParsing(const void *opaquePointer) override {
return success(parser.getState().cyclicParsingStack.insert(opaquePointer));
LogicalResult
pushCyclicParsing(PointerUnion<Attribute, Type> attrOrType) override {
return success(parser.getState().cyclicParsingStack.insert(attrOrType));
}

void popCyclicParsing() override {
Expand Down
52 changes: 41 additions & 11 deletions mlir/lib/AsmParser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ using namespace mlir::detail;
/// | distinct-attribute
/// | extended-attribute
///
Attribute Parser::parseAttribute(Type type) {
Attribute Parser::parseAttribute(Type type, StringRef aliasDefName) {
switch (getToken().getKind()) {
// Parse an AffineMap or IntegerSet attribute.
case Token::kw_affine_map: {
Expand Down Expand Up @@ -117,7 +117,7 @@ Attribute Parser::parseAttribute(Type type) {

// Parse an extended attribute, i.e. alias or dialect attribute.
case Token::hash_identifier:
return parseExtendedAttr(type);
return parseExtendedAttr(type, aliasDefName);

// Parse floating point and integer attributes.
case Token::floatliteral:
Expand Down Expand Up @@ -145,6 +145,10 @@ Attribute Parser::parseAttribute(Type type) {
parseLocationInstance(locAttr) ||
parseToken(Token::r_paren, "expected ')' in inline location"))
return Attribute();

if (syntaxOnly())
return state.syntaxOnlyAttr;

return locAttr;
}

Expand Down Expand Up @@ -430,6 +434,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return FloatAttr::get(floatType, *result);
}

if (syntaxOnly())
return state.syntaxOnlyAttr;

if (!isa<IntegerType, IndexType>(type))
return emitError(loc, "integer literal not valid for specified type"),
nullptr;
Expand Down Expand Up @@ -1003,7 +1010,9 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
auto type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;
return literalParser.getAttr(loc, type);
if (syntaxOnly())
return state.syntaxOnlyAttr;
return literalParser.getAttr(loc, cast<ShapedType>(type));
}

Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
Expand All @@ -1030,6 +1039,9 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
return nullptr;
}

if (syntaxOnly())
return state.syntaxOnlyAttr;

ShapedType shapedType = dyn_cast<ShapedType>(attrType);
if (!shapedType) {
emitError(typeLoc, "`dense_resource` expected a shaped type");
Expand All @@ -1044,7 +1056,7 @@ Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
/// elements-literal-type ::= vector-type | ranked-tensor-type
///
/// This method also checks the type has static shape.
ShapedType Parser::parseElementsLiteralType(Type type) {
Type Parser::parseElementsLiteralType(Type type) {
// If the user didn't provide a type, parse the colon type for the literal.
if (!type) {
if (parseToken(Token::colon, "expected ':'"))
Expand All @@ -1053,6 +1065,9 @@ ShapedType Parser::parseElementsLiteralType(Type type) {
return nullptr;
}

if (syntaxOnly())
return state.syntaxOnlyType;

auto sType = dyn_cast<ShapedType>(type);
if (!sType) {
emitError("elements literal must be a shaped type");
Expand All @@ -1077,17 +1092,23 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// of the type.
Type indiceEltType = builder.getIntegerType(64);
if (consumeIf(Token::greater)) {
ShapedType type = parseElementsLiteralType(attrType);
Type type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;

if (syntaxOnly())
return state.syntaxOnlyAttr;

// Construct the sparse elements attr using zero element indice/value
// attributes.
ShapedType shapedType = cast<ShapedType>(type);
ShapedType indicesType =
RankedTensorType::get({0, type.getRank()}, indiceEltType);
ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
RankedTensorType::get({0, shapedType.getRank()}, indiceEltType);
ShapedType valuesType =
RankedTensorType::get({0}, shapedType.getElementType());
return getChecked<SparseElementsAttr>(
loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
loc, shapedType,
DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
}

Expand All @@ -1114,14 +1135,20 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
if (!type)
return nullptr;

if (syntaxOnly())
return state.syntaxOnlyAttr;

ShapedType shapedType = cast<ShapedType>(type);

// If the indices are a splat, i.e. the literal parser parsed an element and
// not a list, we set the shape explicitly. The indices are represented by a
// 2-dimensional shape where the second dimension is the rank of the type.
// Given that the parsed indices is a splat, we know that we only have one
// indice and thus one for the first dimension.
ShapedType indicesType;
if (indiceParser.getShape().empty()) {
indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
indicesType =
RankedTensorType::get({1, shapedType.getRank()}, indiceEltType);
} else {
// Otherwise, set the shape to the one parsed by the literal parser.
indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
Expand All @@ -1131,15 +1158,15 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// If the values are a splat, set the shape explicitly based on the number of
// indices. The number of indices is encoded in the first dimension of the
// indice shape type.
auto valuesEltType = type.getElementType();
auto valuesEltType = shapedType.getElementType();
ShapedType valuesType =
valuesParser.getShape().empty()
? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
: RankedTensorType::get(valuesParser.getShape(), valuesEltType);
auto values = valuesParser.getAttr(valuesLoc, valuesType);

// Build the sparse elements attribute by the indices and values.
return getChecked<SparseElementsAttr>(loc, type, indices, values);
return getChecked<SparseElementsAttr>(loc, shapedType, indices, values);
}

Attribute Parser::parseStridedLayoutAttr() {
Expand Down Expand Up @@ -1260,6 +1287,9 @@ Attribute Parser::parseDistinctAttr(Type type) {
return {};
}

if (syntaxOnly())
return state.syntaxOnlyAttr;

// Add the distinct attribute to the parser state, if it has not been parsed
// before. Otherwise, check if the parsed reference attribute matches the one
// found in the parser state.
Expand Down
72 changes: 61 additions & 11 deletions mlir/lib/AsmParser/DialectSymbolParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/SourceMgr.h"

using namespace mlir;
Expand All @@ -28,18 +29,37 @@ namespace {
/// hooking into the main MLIR parsing logic.
class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
public:
CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
CustomDialectAsmParser(StringRef fullSpec, Parser &parser,
StringRef aliasDefName)
: AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
fullSpec(fullSpec) {}
fullSpec(fullSpec), aliasDefName(aliasDefName) {}
~CustomDialectAsmParser() override = default;

/// Returns the full specification of the symbol being parsed. This allows
/// for using a separate parser if necessary.
StringRef getFullSymbolSpec() const override { return fullSpec; }

/// Overrides `AsmParserImpl::pushCyclicParsing`. When this parser was
/// constructed with a non-empty `aliasDefName`, the given attribute or type
/// is first stored in the parser state's alias definition map under that
/// name; the call is then forwarded to the base implementation, whose result
/// is returned unchanged.
LogicalResult
pushCyclicParsing(PointerUnion<Attribute, Type> attrOrType) override {
// If this is an alias definition, register the mutable attribute or type.
// The map entry is written before delegating to the base implementation, so
// it is stored regardless of whether the call below succeeds or fails.
if (!aliasDefName.empty()) {
// Attributes and types are kept in separate alias maps; pick the one that
// matches the active member of the union.
if (auto attr = dyn_cast<Attribute>(attrOrType))
parser.getState().symbols.attributeAliasDefinitions[aliasDefName] =
attr;
else
parser.getState().symbols.typeAliasDefinitions[aliasDefName] =
cast<Type>(attrOrType);
}
return AsmParserImpl::pushCyclicParsing(attrOrType);
}

private:
/// The full symbol specification.
StringRef fullSpec;

/// If this parser is used to parse an alias definition, the name of the alias
/// definition. Empty otherwise.
StringRef aliasDefName;
};
} // namespace

Expand Down Expand Up @@ -156,9 +176,11 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
}

/// Parse an extended dialect symbol.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
template <typename Symbol, typename SymbolAliasMap, typename ParseAliasFn,
typename CreateFn>
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
SymbolAliasMap &aliases,
ParseAliasFn &parseAliasFn,
CreateFn &&createSymbol) {
Token tok = p.getToken();

Expand All @@ -185,12 +207,32 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
// If there is no '<' token following this, and if the typename contains no
// dot, then we are parsing a symbol alias.
if (!hasTrailingData && !isPrettyName) {

// Don't check the validity of alias reference in syntax-only mode.
if (p.syntaxOnly()) {
if constexpr (std::is_same_v<Symbol, Type>)
return p.getState().syntaxOnlyType;
else
return p.getState().syntaxOnlyAttr;
}

// Check for an alias for this type.
auto aliasIt = aliases.find(identifier);
if (aliasIt == aliases.end())
return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
"'"),
nullptr);
if (aliasIt == aliases.end()) {
FailureOr<Symbol> symbol = failure();
// Try the parse alias function if set.
if (parseAliasFn)
symbol = parseAliasFn(identifier);

if (failed(symbol)) {
p.emitWrongTokenError("undefined symbol alias id '" + identifier + "'");
return nullptr;
}
if (!*symbol)
return nullptr;

aliasIt = aliases.insert({identifier, *symbol}).first;
}
if (asmState) {
if constexpr (std::is_same_v<Symbol, Type>)
asmState->addTypeAliasUses(identifier, range);
Expand Down Expand Up @@ -237,16 +279,20 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
/// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
/// attribute-alias ::= `#` alias-name
///
Attribute Parser::parseExtendedAttr(Type type) {
Attribute Parser::parseExtendedAttr(Type type, StringRef aliasDefName) {
MLIRContext *ctx = getContext();
Attribute attr = parseExtendedSymbol<Attribute>(
*this, state.asmState, state.symbols.attributeAliasDefinitions,
state.symbols.parseUnknownAttributeAlias,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
// Parse an optional trailing colon type.
Type attrType = type;
if (consumeIf(Token::colon) && !(attrType = parseType()))
return Attribute();

if (syntaxOnly())
return state.syntaxOnlyAttr;

// If we found a registered dialect, then ask it to parse the attribute.
if (Dialect *dialect =
builder.getContext()->getOrLoadDialect(dialectName)) {
Expand All @@ -255,7 +301,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
resetToken(symbolData.data());

// Parse the attribute.
CustomDialectAsmParser customParser(symbolData, *this);
CustomDialectAsmParser customParser(symbolData, *this, aliasDefName);
Attribute attr = dialect->parseAttribute(customParser, attrType);
resetToken(curLexerPos);
return attr;
Expand Down Expand Up @@ -284,19 +330,23 @@ Attribute Parser::parseExtendedAttr(Type type) {
/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
/// type-alias ::= `!` alias-name
///
Type Parser::parseExtendedType() {
Type Parser::parseExtendedType(StringRef aliasDefName) {
MLIRContext *ctx = getContext();
return parseExtendedSymbol<Type>(
*this, state.asmState, state.symbols.typeAliasDefinitions,
state.symbols.parseUnknownTypeAlias,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
if (syntaxOnly())
return state.syntaxOnlyType;

// If we found a registered dialect, then ask it to parse the type.
if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
// Temporarily reset the lexer to let the dialect parse the type.
const char *curLexerPos = getToken().getLoc().getPointer();
resetToken(symbolData.data());

// Parse the type.
CustomDialectAsmParser customParser(symbolData, *this);
CustomDialectAsmParser customParser(symbolData, *this, aliasDefName);
Type type = dialect->parseType(customParser);
resetToken(curLexerPos);
return type;
Expand Down
Loading