Skip to content

Commit c78733f

Browse files
committed
[MLIR] Flatten fused locations when merging constants.
[PR 74670](#74670) added support for merging locations at constant folding time. We have discovered that in some cases, the number of locations grows so big as to cause a compilation process to OOM. In that case, many of the locations end up appearing several times in nested fused locations. We add here a helper that always flattens fused locations in order to eliminate duplicates in the case of nested fused locations. We only allow flattening nested fused locations when the inner fused location has no metadata, or has the same metadata as the outer fused location.
1 parent 935c6a2 commit c78733f

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
331331
return newIt.first->second;
332332
}
333333

334+
/// Helper that flattens nested fused locations to a single fused location.
335+
/// Fused locations nested under non-fused locations are not flattened, and
336+
/// calling this on non-fused locations is a no-op as a result.
337+
///
338+
/// Fused locations are only flattened into parent fused locations if the
339+
/// child fused location has no metadata, or if the metadata of the parent and
340+
/// child fused locations are the same---this to avoid breaking cases where
341+
/// metadata matter.
342+
static Location FlattenFusedLocationRecursively(const Location loc) {
343+
if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
344+
SetVector<Location> flattenedLocs;
345+
Attribute metadata = fusedLoc.getMetadata();
346+
347+
for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
348+
Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
349+
auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
350+
351+
if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
352+
flattenedFusedLoc.getMetadata() == metadata)) {
353+
ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
354+
flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
355+
} else {
356+
flattenedLocs.insert(flattenedLoc);
357+
}
358+
}
359+
360+
return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
361+
fusedLoc.getMetadata());
362+
}
363+
364+
return loc;
365+
}
366+
334367
void OperationFolder::appendFoldedLocation(Operation *retainedOp,
335368
Location foldedLocation) {
336369
// Append into existing fused location if it has the same tag.
@@ -344,7 +377,7 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
344377
locations.insert(foldedLocation);
345378
Location newFusedLoc = FusedLoc::get(
346379
retainedOp->getContext(), locations.takeVector(), existingMetadata);
347-
retainedOp->setLoc(newFusedLoc);
380+
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
348381
return;
349382
}
350383
}
@@ -357,5 +390,5 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
357390
Location newFusedLoc =
358391
FusedLoc::get(retainedOp->getContext(),
359392
{retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
360-
retainedOp->setLoc(newFusedLoc);
393+
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
361394
}

mlir/test/Transforms/canonicalize-debuginfo.mlir

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
22

33
// CHECK-LABEL: func @merge_constants
4-
func.func @merge_constants() -> (index, index, index, index) {
4+
func.func @merge_constants() -> (index, index, index, index, index, index, index) {
55
// CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
66
%0 = arith.constant 42 : index loc("merge_constants":0:0)
77
%1 = arith.constant 42 : index loc("merge_constants":1:0)
88
%2 = arith.constant 42 : index loc("merge_constants":2:0)
99
%3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
10-
return %0, %1, %2, %3: index, index, index, index
10+
%4 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
11+
%5 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
12+
%6 = arith.constant 43 : index loc(fused<"some_other_label">["merge_constants":3:0])
13+
return %0, %1, %2, %3, %4, %5, %6 : index, index, index, index, index, index, index
1114
}
1215

1316
// CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
1417
// CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
1518
// CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
16-
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
19+
// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
20+
// CHECK-DAG: #[[FusedLoc_CSE_1:.*]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
21+
// CHECK-DAG: #[[FusedLoc_Some_Label:.*]] = loc(fused<"some_label">[#[[LocConst3]]])
22+
// CHECK-DAG: #[[FusedLoc_Some_Other_Label:.*]] = loc(fused<"some_other_label">[#[[LocConst3]]])
23+
// CHECK-DAG: #[[FusedLoc_CSE_2:.*]] = loc(fused<"CSE">[#[[FusedLoc_Some_Label]], #[[FusedLoc_Some_Other_Label]]])
1724

1825
// -----
1926

0 commit comments

Comments
 (0)