-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Fold linalg.fill
-> linalg.copy
#72920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fold linalg.fill
-> linalg.copy
#72920
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (MaheshRavishankar) ChangesFull diff: https://github.com/llvm/llvm-project/pull/72920.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index d12ba8c4c59b33f..58af9995548e939 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,14 +803,36 @@ struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
}
};
+/// Fold fill with copy.
+struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
+ using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
+ PatternRewriter &rewriter) const override {
+ if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
+ rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
+ fillOp.getInputs(),
+ copyOp.getOutputs());
+ return success();
+ }
+ if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
+ rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
+ fillOp.getOutputs());
+ return success();
+ }
+ return failure();
+ }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldFillWithTensorExtract, FoldFillWithPack, FoldFillWithPad,
- FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
- FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
- FoldInsertPadIntoFill>(context);
+ results
+ .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
+ FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+ FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
+ FoldInsertPadIntoFill>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 7793e435582746c..c054829a915d7ba 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -972,3 +972,16 @@ func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<
%3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %3: tensor<?x?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_fill_to_copy_dest(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK: linalg.copy ins(%[[ARG1]] : tensor<?x?xf32>) outs(%[[ARG0]] : tensor<?x?xf32>)
+func.func @canonicalize_fill_to_copy_dest(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0.0 : f32
+ %fill = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %copy = linalg.copy ins(%arg1 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %copy : tensor<?x?xf32>
+}
|
5ebb2a5
to
ca09afe
Compare
ca09afe
to
4d98245
Compare
linalg.fill
-> linalg.copy
along outs
use in the consumer.linalg.fill
-> linalg.copy
// CHECK-LABEL: func @canonicalize_fill_to_copy_input( | ||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32> | ||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>) | ||
// CHECK: %[[ZERO:.+]] = arith.constant 0.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't DCE get rid of linalg.copy
here?
// CHECK: %[[ZERO:.+]] = arith.constant 0.0 | |
// CHECK: %[[ZERO:.+]] = arith.constant 0.0 | |
// CHECK-NOT: linalg.copy |
Similar comment below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DCE does remove it. It isn't removed by the pattern though explicitly. So I'd rather not test for that here.
No description provided.