Skip to content

Fold linalg.fill -> linalg.copy #72920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2023

Conversation

MaheshRavishankar
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (MaheshRavishankar)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/72920.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+26-4)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+13)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d12ba8c4c59b33f..58af9995548e939 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,14 +803,36 @@ struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
   }
 };
 
+/// Fold fill with copy.
+struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
+  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
+      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
+                                          fillOp.getInputs(),
+                                          copyOp.getOutputs());
+      return success();
+    }
+    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
+      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
+                                                  fillOp.getOutputs());
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldFillWithTensorExtract, FoldFillWithPack, FoldFillWithPad,
-              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
-              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-              FoldInsertPadIntoFill>(context);
+  results
+      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
+           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+           FoldInsertPadIntoFill>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7793e435582746c..c054829a915d7ba 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -972,3 +972,16 @@ func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<
   %3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
   return %3: tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_fill_to_copy_dest(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//       CHECK:   linalg.copy ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[ARG0]] : tensor<?x?xf32>)
+func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %copy : tensor<?x?xf32>
+}

@MaheshRavishankar MaheshRavishankar changed the title Fold linalg.fill -> linalg.copy along outs use in the consumer. Fold linalg.fill -> linalg.copy Nov 21, 2023
// CHECK-LABEL: func @canonicalize_fill_to_copy_input(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't DCE get rid of linalg.copy here?

Suggested change
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
// CHECK: %[[ZERO:.+]] = arith.constant 0.0
// CHECK-NOT: linalg.copy

Similar comment below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DCE does remove it. It isn't removed by the pattern though explicitly. So I'd rather not test for that here.

@nicolasvasilache nicolasvasilache self-requested a review November 21, 2023 11:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants