Skip to content

Commit a338f80

Browse files
[mlir][Linalg] Add transform to convert linalg.copy into memref.copy (#132422)
Targeted rewrite of a linalg.copy on memrefs to a memref.copy. This is useful when bufferizing copies to a linalg.copy, applying some transformations, and then rewriting the copy into a memref.copy. If the element types of the source and destination differ, or if the source is a scalar, the transform produces a silenceable failure.
1 parent 6892d54 commit a338f80

File tree

3 files changed

+193
-0
lines changed

3 files changed

+193
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,40 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
559559
}];
560560
}
561561

562+
//===----------------------------------------------------------------------===//
563+
// LinalgCopyToMemrefOp
564+
//===----------------------------------------------------------------------===//
565+
566+
def LinalgCopyToMemrefOp :
567+
Op<Transform_Dialect, "structured.linalg_copy_to_memref",
568+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
569+
TransformEachOpTrait, TransformOpInterface]> {
570+
let description = [{
571+
Targeted rewrite of a linalg.copy on memrefs to a memref.copy.
572+
This is useful when bufferizing copies to a linalg.copy, later applying some
573+
transformations, and then rewriting the copy into a memref.copy.
574+
If the element types of the source and destination differ, or if the source
575+
is a scalar, the transform produces a silenceable failure.
576+
}];
577+
578+
let arguments = (ins TransformHandleTypeInterface:$target);
579+
let results = (outs TransformHandleTypeInterface:$transformed);
580+
581+
let assemblyFormat = "$target attr-dict `:` "
582+
"functional-type(operands, results) ";
583+
584+
let builders = [
585+
OpBuilder<(ins "Value":$target)>,
586+
];
587+
let extraClassDeclaration = [{
588+
::mlir::DiagnosedSilenceableFailure applyToOne(
589+
::mlir::transform::TransformRewriter &rewriter,
590+
::mlir::Operation *target,
591+
::mlir::transform::ApplyToEachResultList &results,
592+
::mlir::transform::TransformState &state);
593+
}];
594+
}
595+
562596
//===----------------------------------------------------------------------===//
563597
// LowerPackOp
564598
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,71 @@ LogicalResult transform::InterchangeOp::verify() {
11761176
return success();
11771177
}
11781178

1179+
//===----------------------------------------------------------------------===//
1180+
// LinalgCopyToMemrefOp
1181+
//===----------------------------------------------------------------------===//
1182+
1183+
DiagnosedSilenceableFailure transform::LinalgCopyToMemrefOp::applyToOne(
1184+
transform::TransformRewriter &rewriter, Operation *targetOp,
1185+
transform::ApplyToEachResultList &results,
1186+
transform::TransformState &state) {
1187+
1188+
// Check if the target can be converted.
1189+
if (!isa<linalg::CopyOp>(targetOp)) {
1190+
DiagnosedSilenceableFailure diag =
1191+
emitSilenceableError() << "only linalg.copy target ops are supported";
1192+
diag.attachNote(targetOp->getLoc()) << "target op";
1193+
return diag;
1194+
}
1195+
1196+
auto copyOp = dyn_cast<linalg::CopyOp>(targetOp);
1197+
if (!copyOp.hasPureBufferSemantics()) {
1198+
DiagnosedSilenceableFailure diag =
1199+
emitSilenceableError()
1200+
<< "cannot transform a linalg.copy on tensors into a memref.copy";
1201+
diag.attachNote(targetOp->getLoc()) << "target op";
1202+
return diag;
1203+
}
1204+
1205+
SmallVector<Value> inputs = copyOp.getInputs();
1206+
SmallVector<Value> outputs = copyOp.getOutputs();
1207+
assert(inputs.size() == 1 && "expected linalg copy op with one input");
1208+
assert(outputs.size() == 1 && "expected memref copy op with one output");
1209+
Value input = inputs.front();
1210+
Value output = outputs.front();
1211+
1212+
// linalg.copy supports different element types on source/dest whereas
1213+
// memref.copy does not, so we must check that the source and dest types can
1214+
// be handled by memref.copy and otherwise reject the transformation.
1215+
if (!dyn_cast<ShapedType>(input.getType())) {
1216+
DiagnosedSilenceableFailure diag =
1217+
emitSilenceableError()
1218+
<< "cannot transform a linalg.copy which input has no shape";
1219+
diag.attachNote(targetOp->getLoc()) << "target op";
1220+
return diag;
1221+
}
1222+
1223+
// linalg.copy destination must be a shaped type.
1224+
assert(dyn_cast<ShapedType>(output.getType()));
1225+
1226+
if (cast<ShapedType>(input.getType()).getElementType() !=
1227+
cast<ShapedType>(output.getType()).getElementType()) {
1228+
DiagnosedSilenceableFailure diag =
1229+
emitSilenceableError()
1230+
<< "cannot transform a linalg.copy with different source and "
1231+
"destination element types ";
1232+
diag.attachNote(targetOp->getLoc()) << "target op";
1233+
return diag;
1234+
}
1235+
1236+
// Target can be converted, do it.
1237+
auto memrefCopyOp =
1238+
rewriter.replaceOpWithNewOp<memref::CopyOp>(targetOp, input, output);
1239+
1240+
results.push_back(memrefCopyOp);
1241+
return DiagnosedSilenceableFailure::success();
1242+
}
1243+
11791244
//===----------------------------------------------------------------------===//
11801245
// LowerPackOp
11811246
//===----------------------------------------------------------------------===//
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// RUN: mlir-opt -transform-interpreter %s --split-input-file --allow-unregistered-dialect -verify-diagnostics | FileCheck %s
2+
3+
// CHECK: func.func @linalg_copy_to_memref_copy(%[[INPUT:.*]]: memref<128x64xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
4+
// CHECK: memref.copy %[[INPUT]], %[[OUTPUT]] : memref<128x64xf32> to memref<128x64xf32>
5+
// CHECK: return
6+
// CHECK: }
7+
8+
func.func @linalg_copy_to_memref_copy(%input : memref<128x64xf32>, %output : memref<128x64xf32>) {
9+
linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf32>)
10+
return
11+
}
12+
13+
module attributes {transform.with_named_sequence} {
14+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
15+
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
16+
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
17+
transform.yield
18+
}
19+
}
20+
21+
// -----
22+
23+
// CHECK: func.func @linalg_copy_to_memref_copy_strides(%[[INPUT:.*]]: memref<128x32xf32>, %[[OUTPUT:.*]]: memref<128x64xf32>) {
24+
// CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
25+
// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
26+
// CHECK: memref.copy %[[INPUT]], %[[SUBVIEW]] : memref<128x32xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
27+
// CHECK: return
28+
// CHECK: }
29+
30+
func.func @linalg_copy_to_memref_copy_strides(%input : memref<128x32xf32>, %output : memref<128x64xf32>) {
31+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<128x64xf32>
32+
%subview = memref.subview %alloc[0, 32] [128, 32] [1, 1] : memref<128x64xf32> to memref<128x32xf32, strided<[64, 1], offset: 32>>
33+
linalg.copy ins(%input : memref<128x32xf32>) outs(%subview : memref<128x32xf32, strided<[64, 1], offset: 32>>)
34+
return
35+
}
36+
37+
module attributes {transform.with_named_sequence} {
38+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
39+
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
40+
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
41+
transform.yield
42+
}
43+
}
44+
45+
// -----
46+
47+
func.func @linalg_copy_to_memref_copy_tensors(%input : tensor<128x64xf32>, %output : tensor<128x64xf32>) -> tensor<128x64xf32> {
48+
// expected-note @below {{target op}}
49+
%0 = linalg.copy ins(%input : tensor<128x64xf32>) outs(%output : tensor<128x64xf32>) -> tensor<128x64xf32>
50+
return %0 : tensor<128x64xf32>
51+
}
52+
53+
module attributes {transform.with_named_sequence} {
54+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55+
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56+
// expected-error @below {{cannot transform a linalg.copy on tensors into a memref.copy}}
57+
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
58+
transform.yield
59+
}
60+
}
61+
62+
// -----
63+
64+
func.func @linalg_copy_to_memref_copy_different_element(%input : memref<128x64xf32>, %output : memref<128x64xf64>) {
65+
// expected-note @below {{target op}}
66+
linalg.copy ins(%input : memref<128x64xf32>) outs(%output : memref<128x64xf64>)
67+
return
68+
}
69+
70+
module attributes {transform.with_named_sequence} {
71+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
72+
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
73+
// expected-error @below {{cannot transform a linalg.copy with different source and destination element types}}
74+
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
75+
transform.yield
76+
}
77+
}
78+
79+
// -----
80+
81+
func.func @linalg_copy_to_memref_copy_scalar(%input : f64, %output : memref<128x64xf64>) {
82+
// expected-note @below {{target op}}
83+
linalg.copy ins(%input : f64) outs(%output : memref<128x64xf64>)
84+
return
85+
}
86+
87+
module attributes {transform.with_named_sequence} {
88+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
89+
%0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
90+
// expected-error @below {{cannot transform a linalg.copy which input has no shape}}
91+
%1 = transform.structured.linalg_copy_to_memref %0 : (!transform.any_op) -> !transform.any_op
92+
transform.yield
93+
}
94+
}

0 commit comments

Comments
 (0)