@@ -340,28 +340,39 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
340
340
// / child fused locations are the same---this to avoid breaking cases where
341
341
// / metadata matter.
342
342
static Location FlattenFusedLocationRecursively (const Location loc) {
343
- if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
344
- SetVector<Location> flattenedLocs;
345
- Attribute metadata = fusedLoc.getMetadata ();
346
-
347
- for (const Location &unflattenedLoc : fusedLoc.getLocations ()) {
348
- Location flattenedLoc = FlattenFusedLocationRecursively (unflattenedLoc);
349
- auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
350
-
351
- if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata () ||
352
- flattenedFusedLoc.getMetadata () == metadata)) {
353
- ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations ();
354
- flattenedLocs.insert (nestedLocations.begin (), nestedLocations.end ());
355
- } else {
356
- flattenedLocs.insert (flattenedLoc);
357
- }
343
+ auto fusedLoc = dyn_cast<FusedLoc>(loc);
344
+ if (!fusedLoc)
345
+ return loc;
346
+
347
+ SetVector<Location> flattenedLocs;
348
+ Attribute metadata = fusedLoc.getMetadata ();
349
+ ArrayRef<Location> unflattenedLocs = fusedLoc.getLocations ();
350
+ bool hasAnyNestedLocChanged = false ;
351
+
352
+ for (const Location &unflattenedLoc : unflattenedLocs) {
353
+ Location flattenedLoc = FlattenFusedLocationRecursively (unflattenedLoc);
354
+
355
+ auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
356
+ if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata () ||
357
+ flattenedFusedLoc.getMetadata () == metadata)) {
358
+ hasAnyNestedLocChanged = true ;
359
+ ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations ();
360
+ flattenedLocs.insert (nestedLocations.begin (), nestedLocations.end ());
361
+ } else {
362
+ if (flattenedLoc != unflattenedLoc)
363
+ hasAnyNestedLocChanged = true ;
364
+
365
+ flattenedLocs.insert (flattenedLoc);
358
366
}
367
+ }
359
368
360
- return FusedLoc::get (loc->getContext (), flattenedLocs.takeVector (),
361
- fusedLoc.getMetadata ());
369
+ if (!hasAnyNestedLocChanged &&
370
+ unflattenedLocs.size () == flattenedLocs.size ()) {
371
+ return loc;
362
372
}
363
373
364
- return loc;
374
+ return FusedLoc::get (loc->getContext (), flattenedLocs.takeVector (),
375
+ fusedLoc.getMetadata ());
365
376
}
366
377
367
378
void OperationFolder::appendFoldedLocation (Operation *retainedOp,
0 commit comments