Revert "[MLIR] Fuse locations of merged constants (#74670)" #75381
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core
Author: Fangrui Song (MaskRay)
Changes
This reverts commit 87e2e89
We observed significant OOM/timeout issues due to #74670 to quite a few
Full diff: https://github.com/llvm/llvm-project/pull/75381.diff
4 Files Affected:
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 28fa18cf942de4..2600da361496cd 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -33,8 +33,7 @@ class Value;
class OperationFolder {
public:
OperationFolder(MLIRContext *ctx, OpBuilder::Listener *listener = nullptr)
- : fusedLocationTag(StringAttr::get(ctx, "CSE")), interfaces(ctx),
- rewriter(ctx, listener) {}
+ : interfaces(ctx), rewriter(ctx, listener) {}
/// Tries to perform folding on the given `op`, including unifying
/// deduplicated constants. If successful, replaces `op`'s uses with
@@ -96,15 +95,6 @@ class OperationFolder {
Dialect *dialect, Attribute value,
Type type, Location loc);
- // Fuse `foldedLocation` into the Location of `retainedOp`. This will result
- // in `retainedOp` having a FusedLoc with `fusedLocationTag` to help trace the
- // source of the fusion. If `retainedOp` already had a FusedLoc with the same
- // tag, `foldedLocation` will simply be appended to it.
- void appendFoldedLocation(Operation *retainedOp, Location foldedLocation);
-
- /// Tag for annotating fused locations as a result of merging constants.
- StringAttr fusedLocationTag;
-
/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes;
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 056a681718e121..90ee5ba51de3ad 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -141,7 +141,6 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
// If there is an existing constant, replace `op`.
if (folderConstOp) {
notifyRemoval(op);
- appendFoldedLocation(folderConstOp, op->getLoc());
rewriter.replaceOp(op, folderConstOp->getResults());
return false;
}
@@ -295,10 +294,8 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
// Check if an existing mapping already exists.
auto constKey = std::make_tuple(dialect, value, type);
Operation *&constOp = uniquedConstants[constKey];
- if (constOp) {
- appendFoldedLocation(constOp, loc);
+ if (constOp)
return constOp;
- }
// If one doesn't exist, try to materialize one.
if (!(constOp = materializeConstant(dialect, rewriter, value, type, loc)))
@@ -319,7 +316,6 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
// materialized operation in favor of the existing one.
if (auto *existingOp = uniquedConstants.lookup(newKey)) {
notifyRemoval(constOp);
- appendFoldedLocation(existingOp, constOp->getLoc());
rewriter.eraseOp(constOp);
referencedDialects[existingOp].push_back(dialect);
return constOp = existingOp;
@@ -330,65 +326,3 @@ OperationFolder::tryGetOrCreateConstant(ConstantMap &uniquedConstants,
auto newIt = uniquedConstants.insert({newKey, constOp});
return newIt.first->second;
}
-
-/// Helper that flattens nested fused locations to a single fused location.
-/// Fused locations nested under non-fused locations are not flattened, and
-/// calling this on non-fused locations is a no-op as a result.
-///
-/// Fused locations are only flattened into parent fused locations if the
-/// child fused location has no metadata, or if the metadata of the parent and
-/// child fused locations are the same---this to avoid breaking cases where
-/// metadata matter.
-static Location FlattenFusedLocationRecursively(const Location loc) {
- if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
- SetVector<Location> flattenedLocs;
- Attribute metadata = fusedLoc.getMetadata();
-
- for (const Location &unflattenedLoc : fusedLoc.getLocations()) {
- Location flattenedLoc = FlattenFusedLocationRecursively(unflattenedLoc);
- auto flattenedFusedLoc = dyn_cast<FusedLoc>(flattenedLoc);
-
- if (flattenedFusedLoc && (!flattenedFusedLoc.getMetadata() ||
- flattenedFusedLoc.getMetadata() == metadata)) {
- ArrayRef<Location> nestedLocations = flattenedFusedLoc.getLocations();
- flattenedLocs.insert(nestedLocations.begin(), nestedLocations.end());
- } else {
- flattenedLocs.insert(flattenedLoc);
- }
- }
-
- return FusedLoc::get(loc->getContext(), flattenedLocs.takeVector(),
- fusedLoc.getMetadata());
- }
-
- return loc;
-}
-
-void OperationFolder::appendFoldedLocation(Operation *retainedOp,
- Location foldedLocation) {
- // Append into existing fused location if it has the same tag.
- if (auto existingFusedLoc =
- dyn_cast<FusedLocWith<StringAttr>>(retainedOp->getLoc())) {
- StringAttr existingMetadata = existingFusedLoc.getMetadata();
- if (existingMetadata == fusedLocationTag) {
- ArrayRef<Location> existingLocations = existingFusedLoc.getLocations();
- SetVector<Location> locations(existingLocations.begin(),
- existingLocations.end());
- locations.insert(foldedLocation);
- Location newFusedLoc = FusedLoc::get(
- retainedOp->getContext(), locations.takeVector(), existingMetadata);
- retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
- return;
- }
- }
-
- // Create a new fusedloc with retainedOp's loc and foldedLocation.
- // If they're already equal, no need to fuse.
- if (retainedOp->getLoc() == foldedLocation)
- return;
-
- Location newFusedLoc =
- FusedLoc::get(retainedOp->getContext(),
- {retainedOp->getLoc(), foldedLocation}, fusedLocationTag);
- retainedOp->setLoc(FlattenFusedLocationRecursively(newFusedLoc));
-}
diff --git a/mlir/test/Transforms/canonicalize-debuginfo.mlir b/mlir/test/Transforms/canonicalize-debuginfo.mlir
deleted file mode 100644
index 217cc29c0095e2..00000000000000
--- a/mlir/test/Transforms/canonicalize-debuginfo.mlir
+++ /dev/null
@@ -1,41 +0,0 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file -mlir-print-debuginfo | FileCheck %s
-
-// CHECK-LABEL: func @merge_constants
-func.func @merge_constants() -> (index, index, index, index, index, index, index) {
- // CHECK-NEXT: arith.constant 42 : index loc(#[[FusedLoc:.*]])
- %0 = arith.constant 42 : index loc("merge_constants":0:0)
- %1 = arith.constant 42 : index loc("merge_constants":1:0)
- %2 = arith.constant 42 : index loc("merge_constants":2:0)
- %3 = arith.constant 42 : index loc("merge_constants":2:0) // repeated loc
- %4 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
- %5 = arith.constant 43 : index loc(fused<"some_label">["merge_constants":3:0])
- %6 = arith.constant 43 : index loc(fused<"some_other_label">["merge_constants":3:0])
- return %0, %1, %2, %3, %4, %5, %6 : index, index, index, index, index, index, index
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("merge_constants":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("merge_constants":1:0)
-// CHECK-DAG: #[[LocConst2:.*]] = loc("merge_constants":2:0)
-// CHECK-DAG: #[[LocConst3:.*]] = loc("merge_constants":3:0)
-// CHECK-DAG: #[[FusedLoc_CSE_1:.*]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]], #[[LocConst2]]])
-// CHECK-DAG: #[[FusedLoc_Some_Label:.*]] = loc(fused<"some_label">[#[[LocConst3]]])
-// CHECK-DAG: #[[FusedLoc_Some_Other_Label:.*]] = loc(fused<"some_other_label">[#[[LocConst3]]])
-// CHECK-DAG: #[[FusedLoc_CSE_2:.*]] = loc(fused<"CSE">[#[[FusedLoc_Some_Label]], #[[FusedLoc_Some_Other_Label]]])
-
-// -----
-
-// CHECK-LABEL: func @hoist_constant
-func.func @hoist_constant(%arg0: memref<8xi32>) {
- // CHECK-NEXT: arith.constant 42 : i32 loc(#[[FusedLoc:.*]])
- affine.for %arg1 = 0 to 8 {
- %0 = arith.constant 42 : i32 loc("hoist_constant":0:0)
- %1 = arith.constant 42 : i32 loc("hoist_constant":1:0)
- memref.store %0, %arg0[%arg1] : memref<8xi32>
- memref.store %1, %arg0[%arg1] : memref<8xi32>
- }
- return
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("hoist_constant":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("hoist_constant":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst0]], #[[LocConst1]]])
diff --git a/mlir/test/Transforms/constant-fold-debuginfo.mlir b/mlir/test/Transforms/constant-fold-debuginfo.mlir
deleted file mode 100644
index 79a25f860a4841..00000000000000
--- a/mlir/test/Transforms/constant-fold-debuginfo.mlir
+++ /dev/null
@@ -1,34 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-constant-fold -mlir-print-debuginfo | FileCheck %s
-
-// CHECK-LABEL: func @fold_and_merge
-func.func @fold_and_merge() -> (i32, i32) {
- %0 = arith.constant 1 : i32
- %1 = arith.constant 5 : i32
-
- // CHECK-NEXT: [[C:%.+]] = arith.constant 6 : i32 loc(#[[FusedLoc:.*]])
- %2 = arith.addi %0, %1 : i32 loc("fold_and_merge":0:0)
-
- %3 = arith.constant 6 : i32 loc("fold_and_merge":1:0)
-
- return %2, %3: i32, i32
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("fold_and_merge":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("fold_and_merge":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
-
-// -----
-
-// CHECK-LABEL: func @materialize_different_dialect
-func.func @materialize_different_dialect() -> (f32, f32) {
- // CHECK: arith.constant 1.{{0*}}e+00 : f32 loc(#[[FusedLoc:.*]])
- %0 = arith.constant -1.0 : f32
- %1 = math.absf %0 : f32 loc("materialize_different_dialect":0:0)
- %2 = arith.constant 1.0 : f32 loc("materialize_different_dialect":1:0)
-
- return %1, %2: f32, f32
-}
-
-// CHECK-DAG: #[[LocConst0:.*]] = loc("materialize_different_dialect":0:0)
-// CHECK-DAG: #[[LocConst1:.*]] = loc("materialize_different_dialect":1:0)
-// CHECK: #[[FusedLoc]] = loc(fused<"CSE">[#[[LocConst1]], #[[LocConst0]]])
This reverts commit 87e2e89 and its follow-ups 0d1490f (llvm#75218) and 6fe3cd5 (llvm#75312). We observed significant OOM/timeout issues due to llvm#74670 in quite a few services, including google-research/swirl-lm. The follow-ups llvm#75218 and llvm#75312 do not address the issue. Perhaps this is worth more investigation.
If this is causing significant issues with OOM / timeout, I'm open to reverting for now and finding a more efficient approach (I'm proposing one in #75258).
Do you have any data that suggests why the previous fix (#75218) didn't solve the problem? Was it generating super long fused locations? Or was it just due to all the temporary fused locations that were created since we "append" locations incrementally?
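To make the second hypothesis concrete, here is a toy model of it. It is only a sketch and does not use the MLIR API: `ToyContext`, `getFused`, and both helper functions are hypothetical names, and it assumes that every fused list is uniqued in the context and kept alive for the context's lifetime. Under that assumption, appending one location per merge retains a number of entries that grows quadratically with the number of merged constants, while fusing once at the end retains a linear number.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Toy stand-in for a context that uniques immutable location lists and never
// frees them. NOT the MLIR API; every name here is hypothetical.
struct ToyContext {
  std::set<std::vector<int>> uniquedFusedLocs;

  // Returns the uniqued copy of `locs`, creating it if it does not exist yet.
  const std::vector<int> &getFused(const std::vector<int> &locs) {
    return *uniquedFusedLocs.insert(locs).first;
  }

  // Total number of location entries held across all uniqued lists.
  std::size_t totalEntries() const {
    std::size_t n = 0;
    for (const auto &locs : uniquedFusedLocs)
      n += locs.size();
    return n;
  }
};

// Merges `numConstants` duplicates one at a time, creating a new fused list
// on every merge, the way an incremental "append" would.
std::size_t entriesAfterIncrementalAppend(int numConstants) {
  assert(numConstants >= 1);
  ToyContext ctx;
  std::vector<int> current = {0}; // location of the retained constant
  for (int i = 1; i < numConstants; ++i) {
    current.push_back(i);  // location of the constant being folded away
    ctx.getFused(current); // each intermediate list stays alive
  }
  return ctx.totalEntries();
}

// Builds a single fused list after all duplicates are known.
std::size_t entriesAfterSingleFusion(int numConstants) {
  assert(numConstants >= 1);
  ToyContext ctx;
  std::vector<int> all;
  for (int i = 0; i < numConstants; ++i)
    all.push_back(i);
  if (all.size() > 1)
    ctx.getFused(all);
  return ctx.totalEntries();
}
```

In this model, `entriesAfterIncrementalAppend(1000)` retains 500499 entries while `entriesAfterSingleFusion(1000)` retains 1000. Whether this is what actually happened in the failing workloads is not established in this thread.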
Thank you for being open to reverts! I don't know the technical aspects of the issue... but I'll find the right folks to help investigate it. Stay tuned!
Thanks for the approval. Sorry that I manually pushed it as 2a9d8ca
Note the breakages are from several tests from SWIRL-LM. These tests are currently only kept internally, but we can expose some if that helps with the debugging. (Understanding the pattern that exploded would be useful though)
We'd need someone downstream to actually work on narrowing down the issue without upstream having to build TensorFlow and learn how to debug within TensorFlow.
@yfcyfcyfc could you ping me internally? I could try and help to narrow down.
FYI: we are going with a different approach; there is no plan to re-land this feature. So there is no urgency to help debug this just now.
This reverts commit 87e2e89
and its follow-up 0d1490f (#75218).
We observed significant OOM/timeout issues due to #74670 in quite a few
services, including google-research/swirl-lm. The follow-up #75218 does
not address the issue. Perhaps this is worth more investigation.