Skip to content

Commit de3f7e2

Browse files
committed
[mlir] Fix infinite recursion in alias initializer
The alias initializer keeps a list of child indices around. When an alias is then marked as non-deferrable, all children are also marked non-deferrable. This is currently done naively which leads to an infinite recursion if using mutable types or attributes containing a cycle. This patch fixes this by adding an early return if the alias is already marked non-deferrable. Since this function is the only way to mark an alias as non-deferrable, it is guaranteed that if it is marked non-deferrable, all its children are as well, and it is not required to walk all the children. This incidentally makes the non-deferrable marking also `O(n)` instead of `O(n^2)` (although not performance sensitive obviously). Differential Revision: https://reviews.llvm.org/D158932
1 parent 57390c9 commit de3f7e2

File tree

6 files changed

+94
-3
lines changed

6 files changed

+94
-3
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,12 @@ std::pair<size_t, size_t> AliasInitializer::visitImpl(
10561056

10571057
void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
10581058
auto it = std::next(aliases.begin(), aliasIndex);
1059+
1060+
// If already marked non-deferrable stop the recursion.
1061+
// All children should already be marked non-deferrable as well.
1062+
if (!it->second.canBeDeferred)
1063+
return;
1064+
10591065
it->second.canBeDeferred = false;
10601066

10611067
// Propagate the non-deferrable flag to any child aliases.

mlir/test/IR/recursive-type.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
22

33
// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
4+
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
5+
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
46

57
// CHECK-LABEL: @roundtrip
68
func.func @roundtrip() {
@@ -12,6 +14,16 @@ func.func @roundtrip() {
1214
// into inifinite recursion.
1315
// CHECK: !testrec
1416
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<type_to_alias, test_rec<type_to_alias>>
17+
18+
// CHECK: () -> ![[$NAME]]
19+
// CHECK: () -> ![[$NAME]]
20+
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
21+
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name, !test.test_rec_alias<name>>
22+
23+
// CHECK: () -> ![[$NAME2]]
24+
// CHECK: () -> ![[$NAME2]]
25+
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
26+
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
1527
return
1628
}
1729

mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
218218
return AliasResult::FinalAlias;
219219
}
220220
}
221+
if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
222+
os << recAliasType.getName();
223+
return AliasResult::FinalAlias;
224+
}
221225
return AliasResult::NoAlias;
222226
}
223227

mlir/test/lib/Dialect/Test/TestTypeDefs.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,4 +373,22 @@ def TestI32 : Test_Type<"TestI32"> {
373373
let mnemonic = "i32";
374374
}
375375

376+
def TestRecursiveAlias
377+
: Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> {
378+
let mnemonic = "test_rec_alias";
379+
let storageClass = "TestRecursiveTypeStorage";
380+
let storageNamespace = "test";
381+
let genStorageClass = 0;
382+
383+
let parameters = (ins "llvm::StringRef":$name);
384+
385+
let hasCustomAssemblyFormat = 1;
386+
387+
let extraClassDeclaration = [{
388+
Type getBody() const;
389+
390+
void setBody(Type type);
391+
}];
392+
}
393+
376394
#endif // TEST_TYPEDEFS

mlir/test/lib/Dialect/Test/TestTypes.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,3 +482,54 @@ void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
482482
SetVector<Type> stack;
483483
printTestType(type, printer, stack);
484484
}
485+
486+
Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
487+
488+
void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
489+
490+
StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
491+
492+
Type TestRecursiveAliasType::parse(AsmParser &parser) {
493+
thread_local static SetVector<Type> stack;
494+
495+
StringRef name;
496+
if (parser.parseLess() || parser.parseKeyword(&name))
497+
return Type();
498+
auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
499+
500+
// If this type already has been parsed above in the stack, expect just the
501+
// name.
502+
if (stack.contains(rec)) {
503+
if (failed(parser.parseGreater()))
504+
return Type();
505+
return rec;
506+
}
507+
508+
// Otherwise, parse the body and update the type.
509+
if (failed(parser.parseComma()))
510+
return Type();
511+
stack.insert(rec);
512+
Type subtype;
513+
if (parser.parseType(subtype))
514+
return nullptr;
515+
stack.pop_back();
516+
if (!subtype || failed(parser.parseGreater()))
517+
return Type();
518+
519+
rec.setBody(subtype);
520+
521+
return rec;
522+
}
523+
524+
void TestRecursiveAliasType::print(AsmPrinter &printer) const {
525+
thread_local static SetVector<Type> stack;
526+
527+
printer << "<" << getName();
528+
if (!stack.contains(*this)) {
529+
printer << ", ";
530+
stack.insert(*this);
531+
printer << getBody();
532+
stack.pop_back();
533+
}
534+
printer << ">";
535+
}

mlir/test/lib/Dialect/Test/TestTypes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,6 @@ struct FieldParser<std::optional<int>> {
9191

9292
#include "TestTypeInterfaces.h.inc"
9393

94-
#define GET_TYPEDEF_CLASSES
95-
#include "TestTypeDefs.h.inc"
96-
9794
namespace test {
9895

9996
/// Storage for simple named recursive types, where the type is identified by
@@ -150,4 +147,7 @@ class TestRecursiveType
150147

151148
} // namespace test
152149

150+
#define GET_TYPEDEF_CLASSES
151+
#include "TestTypeDefs.h.inc"
152+
153153
#endif // MLIR_TESTTYPES_H

0 commit comments

Comments
 (0)