Skip to content

Commit 6fe3cd5

Browse files
authored
[MLIR][NFC] Add fast path to fused loc flattening. (#75312)
This is a follow-up on [PR 75218](#75218) that avoids reconstructing a fused loc in the `FlattenFusedLocationRecursively` helper when there has been no change.
1 parent 35dacf2 commit 6fe3cd5

File tree

1 file changed

+29
-18
lines changed

1 file changed

+29
-18
lines changed

mlir/lib/Transforms/Utils/FoldUtils.cpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
340340
/// child fused locations are the same---this to avoid breaking cases where
341341
/// metadata matter.
342342
static Location FlattenFusedLocationRecursively(const Location loc) {
343-
if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
344-
SetVector<Location> flattenedLocs;
345-
Attribute metadata = fusedLoc.getMetadata();
346-
347-
for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
348-
Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
349-
auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
350-
351-
if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
352-
flattenedFusedLoc.getMetadata() == metadata)) {
353-
ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
354-
flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
355-
} else {
356-
flattenedLocs.insert(flattenedLoc);
357-
}
343+
auto fusedLoc = dyn_cast<FusedLoc>(loc);
344+
if (!fusedLoc)
345+
return loc;
346+
347+
SetVector<Location> flattenedLocs;
348+
Attribute metadata = fusedLoc.getMetadata();
349+
ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations();
350+
bool hasAnyNestedLocChanged = false;
351+
352+
for (const Location &unflattenedLoc : unflattenedLocs) {
353+
Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
354+
355+
auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
356+
if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
357+
flattenedFusedLoc.getMetadata() == metadata)) {
358+
hasAnyNestedLocChanged = true;
359+
ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
360+
flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
361+
} else {
362+
if (flattenedLoc != unflattenedLoc)
363+
hasAnyNestedLocChanged = true;
364+
365+
flattenedLocs.insert(flattenedLoc);
358366
}
367+
}
359368

360-
return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
361-
fusedLoc.getMetadata());
369+
if (!hasAnyNestedLocChanged &&
370+
unflattenedLocs.size() == flattenedLocs.size()) {
371+
return loc;
362372
}
363373

364-
return loc;
374+
return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
375+
fusedLoc.getMetadata());
365376
}
366377

367378
void OperationFolder::appendFoldedLocation(Operation *retainedOp,

0 commit comments

Comments
 (0)