Skip to content

Commit b121c26

Browse files
authored
[mlir] Add helper method to print and parse cyclic attributes and types (#65210)
Printing cyclic attributes and types currently has no first-class support within the AsmPrinter and AsmParser. The workaround for this issue used in all mutable attributes and types upstream has been to create a `thread_local static SetVector` keeping track of currently parsed and printed attributes. This solution is not ideal readability-wise due to the use of globals and keeping track of state. Worst of all, this pattern had to be reimplemented for every mutable attribute and type. This patch therefore adds support for this pattern in `AsmPrinter` and `AsmParser`, replacing the hand-written uses of this pattern. By calling `tryStartCyclicPrint/Parse`, the mutable attribute or type is registered in an internal stack. All subsequent calls to the function with the same attribute or type will lead to returning failure. This way the nesting can be detected and a short form printed or parsed instead. Through the resetter returned by the call, the cyclic printing or parsing region automatically ends on return.
1 parent 8031a08 commit b121c26

File tree

9 files changed

+240
-118
lines changed

9 files changed

+240
-118
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,69 @@ class AsmPrinter {
222222
printArrowTypeList(results);
223223
}
224224

225+
/// Class used to automatically end a cyclic region on destruction.
226+
class CyclicPrintReset {
227+
public:
228+
explicit CyclicPrintReset(AsmPrinter *printer) : printer(printer) {}
229+
230+
~CyclicPrintReset() {
231+
if (printer)
232+
printer->popCyclicPrinting();
233+
}
234+
235+
CyclicPrintReset(const CyclicPrintReset &) = delete;
236+
237+
CyclicPrintReset &operator=(const CyclicPrintReset &) = delete;
238+
239+
CyclicPrintReset(CyclicPrintReset &&rhs)
240+
: printer(std::exchange(rhs.printer, nullptr)) {}
241+
242+
CyclicPrintReset &operator=(CyclicPrintReset &&rhs) {
243+
printer = std::exchange(rhs.printer, nullptr);
244+
return *this;
245+
}
246+
247+
private:
248+
AsmPrinter *printer;
249+
};
250+
251+
/// Attempts to start a cyclic printing region for `attrOrType`.
252+
/// A cyclic printing region starts with this call and ends with the
253+
/// destruction of the returned `CyclicPrintReset`. During this time,
254+
/// calling `tryStartCyclicPrint` with the same attribute in any printer
255+
/// will lead to returning failure.
256+
///
257+
/// This makes it possible to break infinite recursions when trying to print
258+
/// cyclic attributes or types by printing only immutable parameters if nested
259+
/// within itself.
260+
template <class AttrOrTypeT>
261+
FailureOr<CyclicPrintReset> tryStartCyclicPrint(AttrOrTypeT attrOrType) {
262+
static_assert(
263+
std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>,
264+
AttrOrTypeT> ||
265+
std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>,
266+
"Only mutable attributes or types can be cyclic");
267+
if (failed(pushCyclicPrinting(attrOrType.getAsOpaquePointer())))
268+
return failure();
269+
return CyclicPrintReset(this);
270+
}
271+
225272
protected:
226273
/// Initialize the printer with no internal implementation. In this case, all
227274
/// virtual methods of this class must be overridden.
228275
AsmPrinter() = default;
229276

277+
/// Pushes a new attribute or type in the form of a type erased pointer
278+
/// into an internal set.
279+
/// Returns success if the type or attribute was inserted in the set or
280+
/// failure if it was already contained.
281+
virtual LogicalResult pushCyclicPrinting(const void *opaquePointer);
282+
283+
/// Removes the element that was last inserted with a successful call to
284+
/// `pushCyclicPrinting`. There must be exactly one `popCyclicPrinting` call
285+
/// in reverse order of all successful `pushCyclicPrinting`.
286+
virtual void popCyclicPrinting();
287+
230288
private:
231289
AsmPrinter(const AsmPrinter &) = delete;
232290
void operator=(const AsmPrinter &) = delete;
@@ -1265,12 +1323,67 @@ class AsmParser {
12651323
/// next token.
12661324
virtual ParseResult parseXInDimensionList() = 0;
12671325

1326+
/// Class used to automatically end a cyclic region on destruction.
1327+
class CyclicParseReset {
1328+
public:
1329+
explicit CyclicParseReset(AsmParser *parser) : parser(parser) {}
1330+
1331+
~CyclicParseReset() {
1332+
if (parser)
1333+
parser->popCyclicParsing();
1334+
}
1335+
1336+
CyclicParseReset(const CyclicParseReset &) = delete;
1337+
CyclicParseReset &operator=(const CyclicParseReset &) = delete;
1338+
CyclicParseReset(CyclicParseReset &&rhs)
1339+
: parser(std::exchange(rhs.parser, nullptr)) {}
1340+
CyclicParseReset &operator=(CyclicParseReset &&rhs) {
1341+
parser = std::exchange(rhs.parser, nullptr);
1342+
return *this;
1343+
}
1344+
1345+
private:
1346+
AsmParser *parser;
1347+
};
1348+
1349+
/// Attempts to start a cyclic parsing region for `attrOrType`.
1350+
/// A cyclic parsing region starts with this call and ends with the
1351+
/// destruction of the returned `CyclicParseReset`. During this time,
1352+
/// calling `tryStartCyclicParse` with the same attribute in any parser
1353+
/// will lead to returning failure.
1354+
///
1355+
/// This makes it possible to parse cyclic attributes or types by parsing a
1356+
/// short from if nested within itself.
1357+
template <class AttrOrTypeT>
1358+
FailureOr<CyclicParseReset> tryStartCyclicParse(AttrOrTypeT attrOrType) {
1359+
static_assert(
1360+
std::is_base_of_v<AttributeTrait::IsMutable<AttrOrTypeT>,
1361+
AttrOrTypeT> ||
1362+
std::is_base_of_v<TypeTrait::IsMutable<AttrOrTypeT>, AttrOrTypeT>,
1363+
"Only mutable attributes or types can be cyclic");
1364+
if (failed(pushCyclicParsing(attrOrType.getAsOpaquePointer())))
1365+
return failure();
1366+
1367+
return CyclicParseReset(this);
1368+
}
1369+
12681370
protected:
12691371
/// Parse a handle to a resource within the assembly format for the given
12701372
/// dialect.
12711373
virtual FailureOr<AsmDialectResourceHandle>
12721374
parseResourceHandle(Dialect *dialect) = 0;
12731375

1376+
/// Pushes a new attribute or type in the form of a type erased pointer
1377+
/// into an internal set.
1378+
/// Returns success if the type or attribute was inserted in the set or
1379+
/// failure if it was already contained.
1380+
virtual LogicalResult pushCyclicParsing(const void *opaquePointer) = 0;
1381+
1382+
/// Removes the element that was last inserted with a successful call to
1383+
/// `pushCyclicParsing`. There must be exactly one `popCyclicParsing` call
1384+
/// in reverse order of all successful `pushCyclicParsing`.
1385+
virtual void popCyclicParsing() = 0;
1386+
12741387
//===--------------------------------------------------------------------===//
12751388
// Code Completion
12761389
//===--------------------------------------------------------------------===//

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,14 @@ class AsmParserImpl : public BaseT {
570570
return parser.parseXInDimensionList();
571571
}
572572

573+
/// Records `opaquePointer` on the parser state's cyclic parsing stack.
/// Returns success if it was newly inserted and failure if it was already
/// present.
LogicalResult pushCyclicParsing(const void *opaquePointer) override {
  auto &regionStack = parser.getState().cyclicParsingStack;
  const bool newlyInserted = regionStack.insert(opaquePointer);
  return success(newlyInserted);
}
576+
577+
/// Removes the most recently inserted entry from the parser state's cyclic
/// parsing stack.
void popCyclicParsing() override {
  auto &regionStack = parser.getState().cyclicParsingStack;
  regionStack.pop_back();
}
580+
573581
//===--------------------------------------------------------------------===//
574582
// Code Completion
575583
//===--------------------------------------------------------------------===//

mlir/lib/AsmParser/ParserState.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "Lexer.h"
1313
#include "mlir/IR/Attributes.h"
1414
#include "mlir/IR/OpImplementation.h"
15+
#include "llvm/ADT/SetVector.h"
1516
#include "llvm/ADT/StringMap.h"
1617

1718
namespace mlir {
@@ -70,6 +71,10 @@ struct ParserState {
7071
/// The current state for symbol parsing.
7172
SymbolState &symbols;
7273

74+
/// Stack of potentially cyclic mutable attributes or types currently being
75+
/// parsed.
76+
SetVector<const void *> cyclicParsingStack;
77+
7378
/// An optional pointer to a struct containing high level parser state to be
7479
/// populated during parsing.
7580
AsmParserState *asmState;

mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,16 @@ static StringRef getTypeKeyword(Type type) {
5454
/// Prints a structure type. Keeps track of known struct names to handle self-
5555
/// or mutually-referring structs without falling into infinite recursion.
5656
static void printStructType(AsmPrinter &printer, LLVMStructType type) {
57-
// This keeps track of the names of identified structure types that are
58-
// currently being printed. Since such types can refer themselves, this
59-
// tracking is necessary to stop the recursion: the current function may be
60-
// called recursively from AsmPrinter::printType after the appropriate
61-
// dispatch. We maintain the invariant of this storage being modified
62-
// exclusively in this function, and at most one name being added per call.
63-
// TODO: consider having such functionality inside AsmPrinter.
64-
thread_local SetVector<StringRef> knownStructNames;
65-
unsigned stackSize = knownStructNames.size();
66-
(void)stackSize;
67-
auto guard = llvm::make_scope_exit([&]() {
68-
assert(knownStructNames.size() == stackSize &&
69-
"malformed identified stack when printing recursive structs");
70-
});
57+
FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
7158

7259
printer << "<";
7360
if (type.isIdentified()) {
61+
cyclicPrint = printer.tryStartCyclicPrint(type);
62+
7463
printer << '"' << type.getName() << '"';
7564
// If we are printing a reference to one of the enclosing structs, just
7665
// print the name and stop to avoid infinitely long output.
77-
if (knownStructNames.count(type.getName())) {
66+
if (failed(cyclicPrint)) {
7867
printer << '>';
7968
return;
8069
}
@@ -91,12 +80,8 @@ static void printStructType(AsmPrinter &printer, LLVMStructType type) {
9180

9281
// Put the current type on stack to avoid infinite recursion.
9382
printer << '(';
94-
if (type.isIdentified())
95-
knownStructNames.insert(type.getName());
9683
llvm::interleaveComma(type.getBody(), printer.getStream(),
9784
[&](Type subtype) { dispatchPrint(printer, subtype); });
98-
if (type.isIdentified())
99-
knownStructNames.pop_back();
10085
printer << ')';
10186
printer << '>';
10287
}
@@ -198,21 +183,6 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
198183
/// | `struct<` string-literal `>`
199184
/// | `struct<` string-literal `, opaque>`
200185
static LLVMStructType parseStructType(AsmParser &parser) {
201-
// This keeps track of the names of identified structure types that are
202-
// currently being parsed. Since such types can refer themselves, this
203-
// tracking is necessary to stop the recursion: the current function may be
204-
// called recursively from AsmParser::parseType after the appropriate
205-
// dispatch. We maintain the invariant of this storage being modified
206-
// exclusively in this function, and at most one name being added per call.
207-
// TODO: consider having such functionality inside AsmParser.
208-
thread_local SetVector<StringRef> knownStructNames;
209-
unsigned stackSize = knownStructNames.size();
210-
(void)stackSize;
211-
auto guard = llvm::make_scope_exit([&]() {
212-
assert(knownStructNames.size() == stackSize &&
213-
"malformed identified stack when parsing recursive structs");
214-
});
215-
216186
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
217187

218188
if (failed(parser.parseLess()))
@@ -224,11 +194,18 @@ static LLVMStructType parseStructType(AsmParser &parser) {
224194
std::string name;
225195
bool isIdentified = succeeded(parser.parseOptionalString(&name));
226196
if (isIdentified) {
227-
if (knownStructNames.count(name)) {
228-
if (failed(parser.parseGreater()))
229-
return LLVMStructType();
230-
return LLVMStructType::getIdentifiedChecked(
197+
SMLoc greaterLoc = parser.getCurrentLocation();
198+
if (succeeded(parser.parseOptionalGreater())) {
199+
auto type = LLVMStructType::getIdentifiedChecked(
231200
[loc] { return emitError(loc); }, loc.getContext(), name);
201+
if (succeeded(parser.tryStartCyclicParse(type))) {
202+
parser.emitError(
203+
greaterLoc,
204+
"struct without a body only allowed in a recursive struct");
205+
return nullptr;
206+
}
207+
208+
return type;
232209
}
233210
if (failed(parser.parseComma()))
234211
return LLVMStructType();
@@ -251,6 +228,18 @@ static LLVMStructType parseStructType(AsmParser &parser) {
251228
return type;
252229
}
253230

231+
FailureOr<AsmParser::CyclicParseReset> cyclicParse;
232+
if (isIdentified) {
233+
cyclicParse =
234+
parser.tryStartCyclicParse(LLVMStructType::getIdentifiedChecked(
235+
[loc] { return emitError(loc); }, loc.getContext(), name));
236+
if (failed(cyclicParse)) {
237+
parser.emitError(kwLoc,
238+
"identifier already used for an enclosing struct");
239+
return nullptr;
240+
}
241+
}
242+
254243
// Check for packedness.
255244
bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
256245
if (failed(parser.parseLParen()))
@@ -273,14 +262,10 @@ static LLVMStructType parseStructType(AsmParser &parser) {
273262
SmallVector<Type, 4> subtypes;
274263
SMLoc subtypesLoc = parser.getCurrentLocation();
275264
do {
276-
if (isIdentified)
277-
knownStructNames.insert(name);
278265
Type type;
279266
if (dispatchParse(parser, type))
280267
return LLVMStructType();
281268
subtypes.push_back(type);
282-
if (isIdentified)
283-
knownStructNames.pop_back();
284269
} while (succeeded(parser.parseOptionalComma()));
285270

286271
if (parser.parseRParen() || parser.parseGreater())

0 commit comments

Comments
 (0)