Skip to content

Commit 0d1490f

Browse files
authored
[MLIR] Flatten fused locations when merging constants. (llvm#75218)
[PR 74670](llvm#74670) added support for merging locations at constant folding time. We have discovered that in some cases, the number of locations grows so big as to cause a compilation process to OOM. In that case, many of the locations end up appearing several times in nested fused locations. We add here a helper that always flattens fused locations in order to eliminate duplicates in the case of nested fused locations.
1 parent fe6f137 commit 0d1490f

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
331331
return newIt.first->second;
332332
}
333333

334+
/// Helper that flattens nested fused locations to a single fused location.
335+
/// Fused locations nested under non-fused locations are not flattened, and
336+
/// calling this on non-fused locations is a no-op as a result.
337+
///
338+
/// Fused locations are only flattened into parent fused locations if the
339+
/// child fused location has no metadata, or if the metadata of the parent and
340+
/// child fused locations are the same---this to avoid breaking cases where
341+
/// metadata matter.
342+
static Location FlattenFusedLocationRecursively(const Location loc) {
343+
if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
344+
SetVector<Location> flattenedLocs;
345+
Attribute metadata = fusedLoc.getMetadata();
346+
347+
for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
348+
Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
349+
auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
350+
351+
if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
352+
flattenedFusedLoc.getMetadata() == metadata)) {
353+
ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
354+
flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
355+
} else {
356+
flattenedLocs.insert(flattenedLoc);
357+
}
358+
}
359+
360+
return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
361+
fusedLoc.getMetadata());
362+
}
363+
364+
return loc;
365+
}
366+
334367
void OperationFolder::appendFoldedLocation(Operation *retainedOp,
335368
Location foldedLocation) {
336369
// Append into existing fused location if it has the same tag.
@@ -344,7 +377,7 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
344377
locations.insert(foldedLocation);
345378
Location newFusedLoc = FusedLoc::get(
346379
retainedOp->getContext(), locations.takeVector(), existingMetadata);
347-
retainedOp->setLoc(newFusedLoc);
380+
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
348381
return;
349382
}
350383
}
@@ -357,5 +390,5 @@ void OperationFolder::appendFoldedLocation(Operation *retainedOp,
357390
Location newFusedLoc =
358391
FusedLoc::get(retainedOp->getContext(),
359392
{retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
360-
retainedOp->setLoc(newFusedLoc);
393+
retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
361394
}

mlir/test/Transforms/canonicalize-debuginfo.mlir

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
22

33
// CHECK-LABEL: func @merge_constants
4-
func.func @merge_constants() -> (index, index, index, index) {
4+
func.func @merge_constants() -> (index, index, index, index, index, index, index) {
55
// CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
66
%0 = arith.constant 42 : index loc("merge_constants":0:0)
77
%1 = arith.constant 42 : index loc("merge_constants":1:0)
88
%2 = arith.constant 42 : index loc("merge_constants":2:0)
99
%3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
10-
return %0, %1, %2, %3: index, index, index, index
10+
%4 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
11+
%5 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
12+
%6 = arith.constant 43 : index loc(fused<"some_other_label">["merge_constants":3:0])
13+
return %0, %1, %2, %3, %4, %5, %6 : index, index, index, index, index, index, index
1114
}
1215

1316
// CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
1417
// CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
1518
// CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
16-
// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
19+
// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
20+
// CHECK-DAG: #[[FusedLoc_CSE_1:.*]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
21+
// CHECK-DAG: #[[FusedLoc_Some_Label:.*]] = loc(fused<"some_label">[#[[LocConst3]]])
22+
// CHECK-DAG: #[[FusedLoc_Some_Other_Label:.*]] = loc(fused<"some_other_label">[#[[LocConst3]]])
23+
// CHECK-DAG: #[[FusedLoc_CSE_2:.*]] = loc(fused<"CSE">[#[[FusedLoc_Some_Label]], #[[FusedLoc_Some_Other_Label]]])
1724

1825
// -----
1926

0 commit comments

Comments
 (0)