Skip to content

[mlir] Add helper method to print and parse cyclic attributes and types #65210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,69 @@ class AsmPrinter {
printArrowTypeList(results);
}

/// Class used to automatically end a cyclic region on destruction.
class CyclicPrintReset {
public:
explicit CyclicPrintReset(AsmPrinter *printer) : printer(printer) {}

~CyclicPrintReset() {
if (printer)
printer->popCyclicPrinting();
}

CyclicPrintReset(const CyclicPrintReset &) = delete;

CyclicPrintReset &operator=(const CyclicPrintReset &) = delete;

CyclicPrintReset(CyclicPrintReset &&rhs)
: printer(std::exchange(rhs.printer, nullptr)) {}

CyclicPrintReset &operator=(CyclicPrintReset &&rhs) {
printer = std::exchange(rhs.printer, nullptr);
return *this;
}

private:
AsmPrinter *printer;
};

/// Tries to open a cyclic printing region for `attrOrType`.
///
/// The region begins with this call and lasts until the returned
/// `CyclicPrintReset` is destroyed. While it is open, any further call to
/// `tryStartCyclicPrint` with the same attribute or type, in any printer,
/// returns failure.
///
/// Custom printers of cyclic attributes or types can use that failure to
/// stop an otherwise infinite recursion: when nested within itself, the
/// value prints only its immutable parameters.
template <class AttrOrTypeT>
FailureOr<CyclicPrintReset> tryStartCyclicPrint(AttrOrTypeT attrOrType) {
  constexpr bool isMutableAttribute =
      std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>;
  constexpr bool isMutableType =
      std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>;
  static_assert(isMutableAttribute || isMutableType,
                "Only mutable attributes or types can be cyclic");

  // The type-erased pointer identifies the value inside the printer.
  const void *opaqueKey = attrOrType.getAsOpaquePointer();
  LogicalResult pushed = pushCyclicPrinting(opaqueKey);
  if (failed(pushed))
    return failure();

  return CyclicPrintReset(this);
}

protected:
/// Initialize the printer with no internal implementation. In this case, all
/// virtual methods of this class must be overridden.
AsmPrinter() = default;

/// Pushes a new attribute or type, in the form of a type-erased pointer,
/// into an internal set.
/// Returns success if the type or attribute was inserted in the set, or
/// failure if it was already contained.
/// `tryStartCyclicPrint` calls this with the value's `getAsOpaquePointer()`
/// and only returns a `CyclicPrintReset` when the push succeeds.
virtual LogicalResult pushCyclicPrinting(const void *opaquePointer);

/// Removes the element that was last inserted with a successful call to
/// `pushCyclicPrinting`. Every successful `pushCyclicPrinting` must be
/// matched by exactly one `popCyclicPrinting`, and the pops must happen in
/// the reverse order of the pushes.
/// Called by `CyclicPrintReset` on destruction.
virtual void popCyclicPrinting();

private:
// Printers are not copyable: both the copy constructor and the copy
// assignment operator are deleted.
AsmPrinter(const AsmPrinter &) = delete;
void operator=(const AsmPrinter &) = delete;
Expand Down Expand Up @@ -1265,12 +1323,67 @@ class AsmParser {
/// next token.
virtual ParseResult parseXInDimensionList() = 0;

/// Class used to automatically end a cyclic region on destruction.
class CyclicParseReset {
public:
explicit CyclicParseReset(AsmParser *parser) : parser(parser) {}

~CyclicParseReset() {
if (parser)
parser->popCyclicParsing();
}

CyclicParseReset(const CyclicParseReset &) = delete;
CyclicParseReset &operator=(const CyclicParseReset &) = delete;
CyclicParseReset(CyclicParseReset &&rhs)
: parser(std::exchange(rhs.parser, nullptr)) {}
CyclicParseReset &operator=(CyclicParseReset &&rhs) {
parser = std::exchange(rhs.parser, nullptr);
return *this;
}

private:
AsmParser *parser;
};

/// Tries to open a cyclic parsing region for `attrOrType`.
///
/// The region begins with this call and lasts until the returned
/// `CyclicParseReset` is destroyed. While it is open, any further call to
/// `tryStartCyclicParse` with the same attribute or type, in any parser,
/// returns failure.
///
/// Custom parsers of cyclic attributes or types can use that failure to
/// parse a short form of the value when it is nested within itself.
template <class AttrOrTypeT>
FailureOr<CyclicParseReset> tryStartCyclicParse(AttrOrTypeT attrOrType) {
  constexpr bool isMutableAttribute =
      std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>;
  constexpr bool isMutableType =
      std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>;
  static_assert(isMutableAttribute || isMutableType,
                "Only mutable attributes or types can be cyclic");

  // The type-erased pointer identifies the value inside the parser.
  const void *opaqueKey = attrOrType.getAsOpaquePointer();
  LogicalResult pushed = pushCyclicParsing(opaqueKey);
  if (failed(pushed))
    return failure();

  return CyclicParseReset(this);
}

protected:
/// Parse a handle to a resource within the assembly format for the given
/// dialect.
/// Pure virtual: the implementation is supplied by a derived class.
virtual FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) = 0;

/// Pushes a new attribute or type, in the form of a type-erased pointer,
/// into an internal set.
/// Returns success if the type or attribute was inserted in the set, or
/// failure if it was already contained.
/// `tryStartCyclicParse` calls this with the value's `getAsOpaquePointer()`
/// and only returns a `CyclicParseReset` when the push succeeds.
virtual LogicalResult pushCyclicParsing(const void *opaquePointer) = 0;

/// Removes the element that was last inserted with a successful call to
/// `pushCyclicParsing`. Every successful `pushCyclicParsing` must be matched
/// by exactly one `popCyclicParsing`, and the pops must happen in the
/// reverse order of the pushes.
/// Called by `CyclicParseReset` on destruction.
virtual void popCyclicParsing() = 0;

//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,14 @@ class AsmParserImpl : public BaseT {
return parser.parseXInDimensionList();
}

/// Records `opaquePointer` on the parser state's cyclic parsing stack.
/// Succeeds only when the pointer was not already on the stack.
LogicalResult pushCyclicParsing(const void *opaquePointer) override {
  auto &activeStack = parser.getState().cyclicParsingStack;
  const bool wasInserted = activeStack.insert(opaquePointer);
  return success(wasInserted);
}

/// Drops the most recently pushed entry from the parser state's cyclic
/// parsing stack.
void popCyclicParsing() override {
  auto &activeStack = parser.getState().cyclicParsingStack;
  activeStack.pop_back();
}

//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/AsmParser/ParserState.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "Lexer.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringMap.h"

namespace mlir {
Expand Down Expand Up @@ -70,6 +71,10 @@ struct ParserState {
/// The current state for symbol parsing.
SymbolState &symbols;

/// Stack of potentially cyclic mutable attributes or types currently being
/// parsed, stored as type-erased pointers.
SetVector<const void *> cyclicParsingStack;

/// An optional pointer to a struct containing high level parser state to be
/// populated during parsing.
AsmParserState *asmState;
Expand Down
69 changes: 27 additions & 42 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,27 +54,16 @@ static StringRef getTypeKeyword(Type type) {
/// Prints a structure type. Keeps track of known struct names to handle self-
/// or mutually-referring structs without falling into infinite recursion.
static void printStructType(AsmPrinter &printer, LLVMStructType type) {
// This keeps track of the names of identified structure types that are
// currently being printed. Since such types can refer themselves, this
// tracking is necessary to stop the recursion: the current function may be
// called recursively from AsmPrinter::printType after the appropriate
// dispatch. We maintain the invariant of this storage being modified
// exclusively in this function, and at most one name being added per call.
// TODO: consider having such functionality inside AsmPrinter.
thread_local SetVector<StringRef> knownStructNames;
unsigned stackSize = knownStructNames.size();
(void)stackSize;
auto guard = llvm::make_scope_exit([&]() {
assert(knownStructNames.size() == stackSize &&
"malformed identified stack when printing recursive structs");
});
FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;

printer << "<";
if (type.isIdentified()) {
cyclicPrint = printer.tryStartCyclicPrint(type);

printer << '"' << type.getName() << '"';
// If we are printing a reference to one of the enclosing structs, just
// print the name and stop to avoid infinitely long output.
if (knownStructNames.count(type.getName())) {
if (failed(cyclicPrint)) {
printer << '>';
return;
}
Expand All @@ -91,12 +80,8 @@ static void printStructType(AsmPrinter &printer, LLVMStructType type) {

// Put the current type on stack to avoid infinite recursion.
printer << '(';
if (type.isIdentified())
knownStructNames.insert(type.getName());
llvm::interleaveComma(type.getBody(), printer.getStream(),
[&](Type subtype) { dispatchPrint(printer, subtype); });
if (type.isIdentified())
knownStructNames.pop_back();
printer << ')';
printer << '>';
}
Expand Down Expand Up @@ -198,21 +183,6 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
/// | `struct<` string-literal `>`
/// | `struct<` string-literal `, opaque>`
static LLVMStructType parseStructType(AsmParser &parser) {
// This keeps track of the names of identified structure types that are
// currently being parsed. Since such types can refer themselves, this
// tracking is necessary to stop the recursion: the current function may be
// called recursively from AsmParser::parseType after the appropriate
// dispatch. We maintain the invariant of this storage being modified
// exclusively in this function, and at most one name being added per call.
// TODO: consider having such functionality inside AsmParser.
thread_local SetVector<StringRef> knownStructNames;
unsigned stackSize = knownStructNames.size();
(void)stackSize;
auto guard = llvm::make_scope_exit([&]() {
assert(knownStructNames.size() == stackSize &&
"malformed identified stack when parsing recursive structs");
});

Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());

if (failed(parser.parseLess()))
Expand All @@ -224,11 +194,18 @@ static LLVMStructType parseStructType(AsmParser &parser) {
std::string name;
bool isIdentified = succeeded(parser.parseOptionalString(&name));
if (isIdentified) {
if (knownStructNames.count(name)) {
if (failed(parser.parseGreater()))
return LLVMStructType();
return LLVMStructType::getIdentifiedChecked(
SMLoc greaterLoc = parser.getCurrentLocation();
if (succeeded(parser.parseOptionalGreater())) {
auto type = LLVMStructType::getIdentifiedChecked(
[loc] { return emitError(loc); }, loc.getContext(), name);
if (succeeded(parser.tryStartCyclicParse(type))) {
parser.emitError(
greaterLoc,
"struct without a body only allowed in a recursive struct");
return nullptr;
}

return type;
}
if (failed(parser.parseComma()))
return LLVMStructType();
Expand All @@ -251,6 +228,18 @@ static LLVMStructType parseStructType(AsmParser &parser) {
return type;
}

FailureOr<AsmParser::CyclicParseReset> cyclicParse;
if (isIdentified) {
cyclicParse =
parser.tryStartCyclicParse(LLVMStructType::getIdentifiedChecked(
[loc] { return emitError(loc); }, loc.getContext(), name));
if (failed(cyclicParse)) {
parser.emitError(kwLoc,
"identifier already used for an enclosing struct");
return nullptr;
}
}

// Check for packedness.
bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
if (failed(parser.parseLParen()))
Expand All @@ -273,14 +262,10 @@ static LLVMStructType parseStructType(AsmParser &parser) {
SmallVector<Type, 4> subtypes;
SMLoc subtypesLoc = parser.getCurrentLocation();
do {
if (isIdentified)
knownStructNames.insert(name);
Type type;
if (dispatchParse(parser, type))
return LLVMStructType();
subtypes.push_back(type);
if (isIdentified)
knownStructNames.pop_back();
} while (succeeded(parser.parseOptionalComma()));

if (parser.parseRParen() || parser.parseGreater())
Expand Down
Loading