Skip to content

[mlir][Linalg] Add transform to convert linalg.copy into memref.copy #132422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 1, 2025

Conversation

pabloantoniom
Copy link
Contributor

Targeted rewrite of a linalg.copy on memrefs to a memref.copy.

This is useful when bufferizing copies to a linalg.copy, later applying some transformations (for instance, tiling), and then rewriting the copy into a memref.copy. If the input linalg.copy has different element type in the source and destination, the transformation is rejected.

Targeted rewrite of a linalg.copy on memrefs to a memref.copy.

This is useful when bufferizing copies to a linalg.copy, later applying
some transformations (for instance, tiling), and then rewriting the
copy into a memref.copy. If the input linalg.copy has different element
type in the source and destination, the transformation is rejected.
@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Pablo Antonio Martinez (pabloantoniom)

Changes

Targeted rewrite of a linalg.copy on memrefs to a memref.copy.

This is useful when bufferizing copies to a linalg.copy, later applying some transformations (for instance, tiling), and then rewriting the copy into a memref.copy. If the input linalg.copy has different element type in the source and destination, the transformation is rejected.


Full diff: https://github.com/llvm/llvm-project/pull/132422.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+34)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+54)
  • (added) mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir (+70)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 12080cee85c9d..8406d170d882e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def LinalgCopyToMemrefOp :
+    Op<Transform_Dialect, "structured.linalg_copy_to_memref",
+      [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+       TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
+    This is useful when bufferizing copies to a linalg.copy, later applying some
+    transformations, and then rewriting the copy into a memref.copy.
+    If the input has different element type in the source and destination,
+    the transformation is rejected.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` "
+                       "functional-type(operands, results) ";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>,
+  ];
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // LowerPackOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d65e7e4666c3..bfebb0fbbf938 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1176,6 +1176,60 @@ LogicalResult transform::InterchangeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LinalgCopyToMemrefOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *targetOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+
+  // Check if the target can be converted
+  if (!isa<linalg::CopyOp>(targetOp)) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "only linalg.copy target ops are supported";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
+  if (!copyOp.hasPureBufferSemantics()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError()
+        << "linalg.copy on tensors cannot be transformed into memref.copy";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  SmallVector<Value> inputs = copyOp.getInputs();
+  SmallVector<Value> outputs = copyOp.getOutputs();
+  assert(inputs.size() == 1 && "expected linalg copy op with one input");
+  assert(outputs.size() == 1 && "expected memref copy op with one output");
+  Value input = inputs.front();
+  Value output = outputs.front();
+
+  // linalg.copy supports different element types on source/dest whereas
+  // memref.copy does not, so we must check here that the types are the same,
+  // otherwise reject the transformation.
+  if (!dyn_cast<ShapedType>(input.getType()) ||
+      cast<ShapedType>(input.getType()).getElementType() !=
+          cast<ShapedType>(output.getType()).getElementType()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "linalg.copy with different source and "
+                                  "destination element types is not supported";
+    diag.attachNote(targetOp->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Target can be converted, do it.
+  auto memrefCopyOp =
+      rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
+
+  results.push_back(memrefCopyOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LowerPackOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
new file mode 100644
index 0000000000000..cd376ef1eb337
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-linalg-copy-to-memref.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+// CHECK:  func.func @linalg_copy_to_memref_copy(%[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
+// CHECK:    memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
+// CHECK:    return
+// CHECK:  }
+
+func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
+  // expected-note @below {{target op}}
+  %0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
+  return %0 : tensor<128x64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+	  // expected-error @below {{linalg.copy on tensors cannot be transformed into memref.copy}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+	  // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x64xf64>) {
+  // expected-note @below {{target op}}
+  linalg.copy ins(%input : f64) outs(%output : memref<128x64xf64>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+	  // expected-error @below {{linalg.copy with different source and destination element types is not supported}}
+    %1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

Copy link
Contributor

@rolfmorel rolfmorel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for contributing to the transform dialect, @pabloantoniom! IMO, code quality wise this is very good.

My understanding is that this runs after bufferization as it only applies to linalg-on-memref. As such it shouldn't have (unfortunate) interactions with bufferization. With that in mind, my understanding is that this is a valid (though very gradual) lowering step. It does raise the question for me: is there currently a pass that does this/something similar? If that's so, could you maybe mention it in the summary, i.e. commit message?

I am okay with this to go in with just my nits addressed, though if you could give it another couple days for others to have a look, that would be great. 👍

TransformEachOpTrait, TransformOpInterface]> {
let description = [{
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
This is useful when bufferizing copies to a linalg.copy, later applying some
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: drop "later"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@pabloantoniom
Copy link
Contributor Author

Thanks for contributing to the transform dialect, @pabloantoniom! IMO, code quality wise this is very good.

My understanding is that this runs after bufferization as it only applies to linalg-on-memref. As such it shouldn't have (unfortunate) interactions with bufferization. With that in mind, my understanding is that this is a valid (though very gradual) lowering step. It does raise the question for me: is there currently a pass that does this/something similar? If that's so, could you maybe mention it in the summary, i.e. commit message?

I am okay with this to go in with just my nits addressed, though if you could give it another couple days for others to have a look, that would be great. 👍

I have changed a couple of small things (mainly rewording) and added a new test. If you like the changes and nobody else has any objection I will merge this shortly. Thanks Rolf! 👍

Re your question: I'm not aware of something like this.

@pabloantoniom pabloantoniom requested a review from rolfmorel March 28, 2025 17:47
Copy link
Contributor

@rolfmorel rolfmorel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pabloantoniom pabloantoniom merged commit a338f80 into llvm:main Apr 1, 2025
11 checks passed
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request Apr 2, 2025
…lvm#132422)

Targeted rewrite of a linalg.copy on memrefs to a memref.copy.

This is useful when bufferizing copies to a linalg.copy, applying some
transformations, and then rewriting the copy into a memref.copy.
If the element types of the source and destination differ, or if the
source is a scalar, the transform produces a silenceable failure.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants