
Revert "[MLIR] Fuse locations of merged constants (#74670)" #75381


Closed
wants to merge 1 commit into from


MaskRay
Member

@MaskRay MaskRay commented Dec 13, 2023

This reverts commit 87e2e89
and its follow-up 0d1490f (#75218).

We observed significant OOM/timeout issues due to #74670 in quite a few
services, including google-research/swirl-lm. The follow-up #75218 does
not address the issue. Perhaps this is worth more investigation.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Dec 13, 2023
@llvmbot
Member

llvmbot commented Dec 13, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Fangrui Song (MaskRay)


Full diff: https://github.com/llvm/llvm-project/pull/75381.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Transforms/FoldUtils.h (+1-11)
  • (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+1-67)
  • (removed) mlir/test/Transforms/canonicalize-debuginfo.mlir (-41)
  • (removed) mlir/test/Transforms/constant-fold-debuginfo.mlir (-34)
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 28fa18cf942de4..2600da361496cd 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -33,8 +33,7 @@ class Value;
 class OperationFolder {
 public:
   OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
-      : fusedLocationTag(StringAttr::get(ctx, "CSE")), interfaces(ctx),
-        rewriter(ctx, listener) {}
+      : interfaces(ctx), rewriter(ctx, listener) {}
 
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
@@ -96,15 +95,6 @@ class OperationFolder {
                                     Dialect *dialect, Attribute value,
                                     Type type, Location loc);
 
-  // Fuse `foldedLocation` into the Location of `retainedOp`. This will result
-  // in `retainedOp` having a FusedLoc with `fusedLocationTag` to help trace the
-  // source of the fusion. If `retainedOp` already had a FusedLoc with the same
-  // tag, `foldedLocation` will simply be appended to it.
-  void appendFoldedLocation(Operation *retainedOp, Location foldedLocation);
-
-  /// Tag for annotating fused locations as a result of merging constants.
-  StringAttr fusedLocationTag;
-
   /// A mapping between an insertion region and the constants that have been
   /// created within it.
   DenseMap<Region *, ConstantMap> foldScopes;
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..90ee5ba51de3ad 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -141,7 +141,6 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
   // If there is an existing constant, replace `op`.
   if (folderConstOp) {
     notifyRemoval(op);
-    appendFoldedLocation(folderConstOp, op->getLoc());
     rewriter.replaceOp(op, folderConstOp->getResults());
     return false;
   }
@@ -295,10 +294,8 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
   // Check if an existing mapping already exists.
   auto constKey = std::make_tuple(dialect, value, type);
   Operation *&constOp = uniquedConstants[constKey];
-  if (constOp) {
-    appendFoldedLocation(constOp, loc);
+  if (constOp)
     return constOp;
-  }
 
   // If one doesn't exist, try to materialize one.
   if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
@@ -319,7 +316,6 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
   // materialized operation in favor of the existing one.
   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
     notifyRemoval(constOp);
-    appendFoldedLocation(existingOp, constOp->getLoc());
     rewriter.eraseOp(constOp);
     referencedDialects[existingOp].push_back(dialect);
     return constOp = existingOp;
@@ -330,65 +326,3 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
   auto newIt = uniquedConstants.insert({newKey, constOp});
   return newIt.first->second;
 }
-
-/// Helper that flattens nested fused locations to a single fused location.
-/// Fused locations nested under non-fused locations are not flattened, and
-/// calling this on non-fused locations is a no-op as a result.
-///
-/// Fused locations are only flattened into parent fused locations if the
-/// child fused location has no metadata, or if the metadata of the parent and
-/// child fused locations are the same---this to avoid breaking cases where
-/// metadata matter.
-static Location FlattenFusedLocationRecursively(const Location loc) {
-  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
-    SetVector<Location> flattenedLocs;
-    Attribute metadata = fusedLoc.getMetadata();
-
-    for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
-      Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
-      auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
-      if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
-                                flattenedFusedLoc.getMetadata() == metadata)) {
-        ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
-        flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
-      } else {
-        flattenedLocs.insert(flattenedLoc);
-      }
-    }
-
-    return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
-                         fusedLoc.getMetadata());
-  }
-
-  return loc;
-}
-
-void OperationFolder::appendFoldedLocation(Operation *retainedOp,
-                                           Location foldedLocation) {
-  // Append into existing fused location if it has the same tag.
-  if (auto existingFusedLoc =
-          dyn_cast<FusedLocWith<StringAttr>>(retainedOp->getLoc())) {
-    StringAttr existingMetadata = existingFusedLoc.getMetadata();
-    if (existingMetadata == fusedLocationTag) {
-      ArrayRef<Location> existingLocations = existingFusedLoc.getLocations();
-      SetVector<Location> locations(existingLocations.begin(),
-                                    existingLocations.end());
-      locations.insert(foldedLocation);
-      Location newFusedLoc = FusedLoc::get(
-          retainedOp->getContext(), locations.takeVector(), existingMetadata);
-      retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
-      return;
-    }
-  }
-
-  // Create a new fusedloc with retainedOp's loc and foldedLocation.
-  // If they're already equal, no need to fuse.
-  if (retainedOp->getLoc() == foldedLocation)
-    return;
-
-  Location newFusedLoc =
-      FusedLoc::get(retainedOp->getContext(),
-                    {retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
-  retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
-}
diff --git a/mlir/test/Transforms/canonicalize-debuginfo.mlir b/mlir/test/Transforms/canonicalize-debuginfo.mlir
deleted file mode 100644
index 217cc29c0095e2..00000000000000
--- a/mlir/test/Transforms/canonicalize-debuginfo.mlir
+++ /dev/null
@@ -1,41 +0,0 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
-
-// CHECK-LABEL: func @merge_constants
-func.func @merge_constants() -> (index, index, index, index, index, index, index) {
-  // CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
-  %0 = arith.constant 42 : index loc("merge_constants":0:0)
-  %1 = arith.constant 42 : index loc("merge_constants":1:0)
-  %2 = arith.constant 42 : index loc("merge_constants":2:0)
-  %3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
-  %4 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
-  %5 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
-  %6 = arith.constant 43 : index loc(fused<"some_other_label">["merge_constants":3:0])
-  return %0, %1, %2, %3, %4, %5, %6 : index, index, index, index, index, index, index
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
-// CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
-// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
-// CHECK-DAG: #[[FusedLoc_CSE_1:.*]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
-// CHECK-DAG: #[[FusedLoc_Some_Label:.*]] = loc(fused<"some_label">[#[[LocConst3]]])
-// CHECK-DAG: #[[FusedLoc_Some_Other_Label:.*]] = loc(fused<"some_other_label">[#[[LocConst3]]])
-// CHECK-DAG: #[[FusedLoc_CSE_2:.*]] = loc(fused<"CSE">[#[[FusedLoc_Some_Label]], #[[FusedLoc_Some_Other_Label]]])
-
-// -----
-
-// CHECK-LABEL: func @hoist_constant
-func.func @hoist_constant(%arg0: memref<8xi32>) {
-  // CHECK-NEXT: arith.constant 42 : i32 loc(#[[FusedLoc:.*]])
-  affine.for %arg1 = 0 to 8 {
-    %0 = arith.constant 42 : i32 loc("hoist_constant":0:0)
-    %1 = arith.constant 42 : i32 loc("hoist_constant":1:0)
-    memref.store %0, %arg0[%arg1] : memref<8xi32>
-    memref.store %1, %arg0[%arg1] : memref<8xi32>
-  }
-  return
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("hoist_constant":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]]])
diff --git a/mlir/test/Transforms/constant-fold-debuginfo.mlir b/mlir/test/Transforms/constant-fold-debuginfo.mlir
deleted file mode 100644
index 79a25f860a4841..00000000000000
--- a/mlir/test/Transforms/constant-fold-debuginfo.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-constant-fold -mlir-print-debuginfo | FileCheck %s
-
-// CHECK-LABEL: func @fold_and_merge
-func.func @fold_and_merge() -> (i32, i32) {
-  %0 = arith.constant 1 : i32
-  %1 = arith.constant 5 : i32
-
-  // CHECK-NEXT: [[C:%.+]] = arith.constant 6 : i32 loc(#[[FusedLoc:.*]])
-  %2 = arith.addi %0, %1 : i32 loc("fold_and_merge":0:0)
-
-  %3 = arith.constant 6 : i32 loc("fold_and_merge":1:0)
-
-  return %2, %3: i32, i32
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("fold_and_merge":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("fold_and_merge":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
-
-// -----
-
-// CHECK-LABEL: func @materialize_different_dialect
-func.func @materialize_different_dialect() -> (f32, f32) {
-  // CHECK: arith.constant 1.{{0*}}e+00 : f32 loc(#[[FusedLoc:.*]])
-  %0 = arith.constant -1.0 : f32
-  %1 = math.absf %0 : f32 loc("materialize_different_dialect":0:0)
-  %2 = arith.constant 1.0 : f32 loc("materialize_different_dialect":1:0)
-
-  return %1, %2: f32, f32
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_different_dialect":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("materialize_different_dialect":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
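The merging rule that this diff removes can be sketched outside MLIR. The block below is a minimal toy model, not the MLIR API: `Loc`, `leaf`, `insertUnique`, `flatten`, and `fuse` are hypothetical names, and plain strings stand in for location attributes and metadata. It mirrors the logic of the removed `FlattenFusedLocationRecursively` and `appendFoldedLocation` as they appear in the diff above.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy location (hypothetical type): a leaf has a `name`; a fused location has
// `children` and an optional `tag` standing in for FusedLoc metadata.
struct Loc {
  std::string name;
  std::string tag;
  std::vector<Loc> children;
  bool isFused() const { return !children.empty(); }
};

bool operator==(const Loc &a, const Loc &b) {
  return a.name == b.name && a.tag == b.tag && a.children == b.children;
}

Loc leaf(const std::string &name) { return Loc{name, "", {}}; }

// Appends `loc` unless an equal location is already present (a stand-in for
// the SetVector used in the removed code).
void insertUnique(std::vector<Loc> &list, const Loc &loc) {
  if (std::find(list.begin(), list.end(), loc) == list.end())
    list.push_back(loc);
}

// Flattens nested fused locations into their parent when the child has no
// tag or the same tag as the parent; other children are kept as they are.
Loc flatten(const Loc &loc) {
  if (!loc.isFused())
    return loc;
  Loc result{"", loc.tag, {}};
  for (const Loc &child : loc.children) {
    Loc flat = flatten(child);
    if (flat.isFused() && (flat.tag.empty() || flat.tag == loc.tag)) {
      for (const Loc &nested : flat.children)
        insertUnique(result.children, nested);
    } else {
      insertUnique(result.children, flat);
    }
  }
  return result;
}

// Returns the location the retained constant would carry after absorbing
// `folded`, following the three cases of the removed appendFoldedLocation.
Loc fuse(const Loc &retained, const Loc &folded, const std::string &tag) {
  // Case 1: already fused with the same tag, so append to it.
  if (retained.isFused() && retained.tag == tag) {
    Loc appended = retained;
    insertUnique(appended.children, folded);
    return flatten(appended);
  }
  // Case 2: identical locations need no fusion.
  if (retained == folded)
    return retained;
  // Case 3: wrap both in a new fused location carrying `tag`.
  return flatten(Loc{"", tag, {retained, folded}});
}
```

Under this model, fusing leaf locations 0, 1, 2, and 2 again with the tag `"CSE"` yields one fused location with three children, and fusing `fused<"some_label">` with `fused<"some_other_label">` keeps both as nested children because their tags differ from `"CSE"`. That matches the shape of the `CHECK-DAG` lines in the deleted `canonicalize-debuginfo.mlir` test.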

@MaskRay MaskRay requested review from zyx-billy, bchetioui and joker-eph and removed request for zyx-billy and bchetioui December 13, 2023 21:14
This reverts commit 87e2e89
and its follow-ups 0d1490f (llvm#75218)
and 6fe3cd5 (llvm#75312).

We observed significant OOM/timeout issues due to llvm#74670 in quite a few
services, including google-research/swirl-lm. The follow-ups llvm#75218 and
llvm#75312 do not address the issue. Perhaps this is worth more
investigation.
Contributor

@zyx-billy zyx-billy left a comment


If this is causing significant issues with OOM / timeout, I'm open to reverting for now and finding a more efficient approach (I'm proposing one in #75258).

Do you have any data that suggests why the previous fix (#75218) didn't solve the problem? Was it generating super long fused locations? Or was it just due to all the temporary fused locations that were created since we "append" locations incrementally?
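The second possibility raised here, temporary fused locations piling up under incremental appends, can be illustrated with a toy model. This is only a sketch of that hypothesis, which the thread does not confirm as the cause. It assumes, as in MLIR, that locations are immutable and uniqued in a context that keeps them for its whole lifetime. `ToyContext`, `appendIncrementally`, and `appendInBatch` are hypothetical names, and integers stand in for locations.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Toy stand-in for a context that uniques immutable location lists and never
// frees them. This is NOT MLIR's API; every name here is hypothetical.
struct ToyContext {
  std::set<std::vector<int>> uniqued; // every distinct list ever created

  // Interns `locs` and returns a pointer to the stored, immutable copy.
  const std::vector<int> *get(std::vector<int> locs) {
    return &*uniqued.insert(std::move(locs)).first;
  }

  // Total number of elements held across all uniqued lists.
  std::size_t storedElements() const {
    std::size_t total = 0;
    for (const auto &list : uniqued)
      total += list.size();
    return total;
  }
};

// Merges `n` duplicates one at a time: each step copies the current list,
// appends one location, and interns the result as a brand-new list.
std::size_t appendIncrementally(std::size_t n) {
  assert(n > 0 && "expects at least one location");
  ToyContext ctx;
  const std::vector<int> *current = ctx.get({0});
  for (std::size_t i = 1; i < n; ++i) {
    std::vector<int> next = *current;
    next.push_back(static_cast<int>(i));
    current = ctx.get(std::move(next));
  }
  return ctx.storedElements();
}

// Collects all `n` locations first and interns a single list at the end.
std::size_t appendInBatch(std::size_t n) {
  assert(n > 0 && "expects at least one location");
  ToyContext ctx;
  std::vector<int> all;
  for (std::size_t i = 0; i < n; ++i)
    all.push_back(static_cast<int>(i));
  ctx.get(std::move(all));
  return ctx.storedElements();
}
```

In this model, merging 1000 duplicates incrementally leaves 500500 stored elements (lists of length 1 through 1000), while the batched variant stores 1000. Storage grows quadratically in the first case and linearly in the second, which is the kind of growth the question above is probing.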

@MaskRay
Member Author

MaskRay commented Dec 13, 2023

> If this is causing significant issues with OOM / timeout, I'm open to reverting for now and finding a more efficient approach (I'm proposing one in #75258).
>
> Do you have any data that suggests why the previous fix (#75218) didn't solve the problem? Was it generating super long fused locations? Or was it just due to all the temporary fused locations that were created since we "append" locations incrementally?

Thank you for being open to a revert! I don't know the technical aspects of the issue, but I'll find the right folks to help investigate it. Stay tuned!

@MaskRay
Member Author

MaskRay commented Dec 13, 2023

Thanks for the approval. Sorry that I manually pushed it as 2a9d8ca.

@MaskRay MaskRay closed this Dec 13, 2023
@MaskRay MaskRay deleted the mlir-fold branch December 13, 2023 22:31
@yfcyfcyfc

yfcyfcyfc commented Dec 14, 2023

Note that the breakages come from several SWIRL-LM tests. These tests are currently kept internal only, but we can expose some if that helps with debugging.

(Understanding the pattern that exploded would be useful though)

@joker-eph
Collaborator

We'd need someone downstream to actually work on narrowing down the issue without upstream having to build TensorFlow and learn how to debug within TensorFlow.
(MLIR has support for "crash reproducers", which could be leveraged to narrow this down to a single-pass reproducer.)

@jpienaar
Member

@yfcyfcyfc could you ping me internally? I could try to help narrow it down.

@joker-eph
Collaborator

FYI: we are going with a different approach, and there is no plan to re-land this feature, so there is no urgency to help debug this just now.
