-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Linalg] Add transform to convert linalg.copy into memref.copy #132422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Targeted rewrite of a linalg.copy on memrefs to a memref.copy. This is useful when bufferizing copies to a linalg.copy, later applying some transformations (for instance, tiling), and then rewriting the copy into a memref.copy. If the input linalg.copy has different element type in the source and destination, the transformation is rejected.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Pablo Antonio Martinez (pabloantoniom) ChangesTargeted rewrite of a linalg.copy on memrefs to a memref.copy. This is useful when bufferizing copies to a linalg.copy, later applying some transformations (for instance, tiling), and then rewriting the copy into a memref.copy. If the input linalg.copy has different element type in the source and destination, the transformation is rejected. Full diff: https://github.com/llvm/llvm-project/pull/132422.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12080cee85c9d..8406d170d882e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
}];
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def LinalgCopyToMemrefOp :
+ Op<Transform_Dialect, "structured.linalg_copy_to_memref",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformEachOpTrait, TransformOpInterface]> {
+ let description = [{
+ Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
+ This is useful when bufferizing copies to a linalg.copy, later applying some
+ transformations, and then rewriting the copy into a memref.copy.
+ If the input has different element type in the source and destination,
+ the transformation is rejected.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target);
+ let results = (outs TransformHandleTypeInterface:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` "
+ "functional-type(operands, results) ";
+
+ let builders = [
+ OpBuilder<(ins "Value":$target)>,
+ ];
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d65e7e4666c3..bfebb0fbbf938 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,6 +1176,60 @@ LogicalResult transform::InterchangeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
+ transform::TransformRewriter &rewriter, Operation *targetOp,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+
+ // Check if the target can be converted
+ if (!isa<linalg::CopyOp>(targetOp)) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "only linalg.copy target ops are supported";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+
+ auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
+ if (!copyOp.hasPureBufferSemantics()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError()
+ << "linalg.copy on tensors cannot be transformed into memref.copy";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+
+ SmallVector<Value> inputs = copyOp.getInputs();
+ SmallVector<Value> outputs = copyOp.getOutputs();
+ assert(inputs.size() == 1 && "expected linalg copy op with one input");
+ assert(outputs.size() == 1 && "expected memref copy op with one output");
+ Value input = inputs.front();
+ Value output = outputs.front();
+
+ // linalg.copy supports different element types on source/dest whereas
+ // memref.copy does not, so we must check here that the types are the same,
+ // otherwise reject the transformation.
+ if (!dyn_cast<ShapedType>(input.getType()) ||
+ cast<ShapedType>(input.getType()).getElementType() !=
+ cast<ShapedType>(output.getType()).getElementType()) {
+ DiagnosedSilenceableFailure diag =
+ emitSilenceableError() << "linalg.copy with different source and "
+ "destination element types is not supported";
+ diag.attachNote(targetOp->getLoc()) << "target op";
+ return diag;
+ }
+
+ // Target can be converted, do it.
+ auto memrefCopyOp =
+ rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
+
+ results.push_back(memrefCopyOp);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// LowerPackOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
new file mode 100644
index 0000000000000..cd376ef1eb337
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+// CHECK: func.func @linalg_copy_to_memref_copy(%[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK: memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
+// CHECK: return
+// CHECK: }
+
+func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
+ linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
+ // expected-note @below {{target op}}
+ %0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
+ return %0 : tensor<128x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy on tensors cannot be transformed into memref.copy}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
+ // expected-note @below {{target op}}
+ linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x64xf64>) {
+ // expected-note @below {{target op}}
+ linalg.copy ins(%input : f64) outs(%output : memref<128x64xf64>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+ %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for contributing to the transform dialect, @pabloantoniom! IMO, code quality wise this is very good.
My understanding is that this runs after bufferization as it only applies to linalg-on-memref. As such it shouldn't have (unfortunate) interactions with bufferization. With that in mind, my understanding is that this is a valid (though very gradual) lowering step. It does raise the question for me: is there currently a pass that does this/something similar? If that's so, could you maybe mention it in the summary, i.e. commit message?
I am okay with this to go in with just my nits addressed, though if you could give it another couple days for others to have a look, that would be great. 👍
TransformEachOpTrait, TransformOpInterface]> { | ||
let description = [{ | ||
Targeted rewrite of a linalg.copy on memrefs to a memref.copy. | ||
This is useful when bufferizing copies to a linalg.copy, later applying some |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: drop "later"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
Outdated
Show resolved
Hide resolved
I have changed a couple of small things (mainly rewording) and added a new test. If you like the changes and nobody else has any objection I will merge this shortly. Thanks Rolf! 👍 Re your question: I'm not aware of something like this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…lvm#132422) Targeted rewrite of a linalg.copy on memrefs to a memref.copy. This is useful when bufferizing copies to a linalg.copy, applying some transformations, and then rewriting the copy into a memref.copy. If the element types of the source and destination differ, or if the source is a scalar, the transform produces a silenceable failure.
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when bufferizing copies to a linalg.copy, later applying some transformations (for instance, tiling), and then rewriting the copy into a memref.copy. If the input linalg.copy has different element type in the source and destination, the transformation is rejected.