Skip to content

[mlir][Linalg] Avoid doing op replacement in linalg::dropUnitDims. #105749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

MaheshRavishankar
Copy link
Contributor

It is better to do the replacement in the caller. This avoids the footgun if the caller needs the original operation. Instead return the produced operation and replacement values.

@llvmbot
Copy link
Member

llvmbot commented Aug 22, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

It is better to do the replacement in the caller. This avoids the footgun if the caller needs the original operation. Instead return the produced operation and replacement values.


Full diff: https://github.com/llvm/llvm-project/pull/105749.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+7-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+8-4)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp (+6-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index bee3452ebb685f..0208f854f799ec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -488,8 +488,13 @@ struct ControlDropUnitDims {
     return SmallVector<unsigned>{};
   };
 };
-LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
-                           const ControlDropUnitDims &options);
+struct DropUnitDimsResult {
+  linalg::GenericOp resultOp;
+  SmallVector<Value> replacements;
+};
+FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter,
+                                           GenericOp genericOp,
+                                           const ControlDropUnitDims &options);
 
 /// Fuse two `linalg.generic` operations that have a producer-consumer
 /// relationship captured through `fusedOperand`. The method expects
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 36f8696bf1b274..38bb3c02d30fa8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -386,7 +386,7 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
   return info;
 }
 
-LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
+FailureOr<DropUnitDimsResult> linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                                    const ControlDropUnitDims &options) {
   SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
   if (indexingMaps.empty())
@@ -545,8 +545,7 @@ LogicalResult linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
     resultReplacements.push_back(expandedValue);
   }
 
-  rewriter.replaceOp(genericOp, resultReplacements);
-  return success();
+  return DropUnitDimsResult{replacementOp, resultReplacements};
 }
 
 namespace {
@@ -557,7 +556,12 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    return dropUnitDims(rewriter, genericOp, options);
+    FailureOr<DropUnitDimsResult> result = dropUnitDims(rewriter, genericOp, options);
+    if (failed(result)) {
+      return failure();
+    }
+    rewriter.replaceOp(genericOp, result->replacements);
+    return success();
   }
 
 private:
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
index 85a6d5f9d9215c..bd43601258bdfc 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDropUnitDims.cpp
@@ -25,7 +25,12 @@ LogicalResult dropOutermostUnitDims(RewriterBase &rewriter,
                                     linalg::GenericOp genericOp) {
   linalg::ControlDropUnitDims options;
   options.controlFn = [](Operation *op) { return SmallVector<unsigned>{0}; };
-  return linalg::dropUnitDims(rewriter, genericOp, options);
+  FailureOr<linalg::DropUnitDimsResult> result = linalg::dropUnitDims(rewriter, genericOp, options);
+  if (failed(result)) {
+    return failure();
+  }
+  rewriter.replaceOp(genericOp, result->replacements);
+  return success();
 }
 
 struct TestLinalgDropUnitDims

Copy link

github-actions bot commented Aug 22, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Aug 22, 2024
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Aug 23, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the users/MaheshRavishankar/drop_unit_dim_result_handling branch 2 times, most recently from b0001be to 749bc95 Compare August 23, 2024 17:22
It is better to do the replacement in the caller. This avoids the
footgun if the caller needs the original operation. Instead return the
produced operation and replacement values.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the users/MaheshRavishankar/drop_unit_dim_result_handling branch from 749bc95 to 5357ecd Compare August 23, 2024 19:12
@@ -488,8 +488,13 @@ struct ControlDropUnitDims {
return SmallVector<unsigned>{};
};
};
LogicalResult dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
const ControlDropUnitDims &options);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a short documentation comment on this struct + function?

@MaheshRavishankar MaheshRavishankar merged commit 4dbaef6 into llvm:main Aug 23, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants