Skip to content

Commit 93a735a

Browse files
[mlir][Linalg] Add a structured transform to materialize a tensor.insert_slice via a linalg.copy
This is useful to materialize copies explicitly before bufferization and transform them, avoiding the need to rediscover them after bufferization. Differential Revision: https://reviews.llvm.org/D148108
1 parent 63c9d2b commit 93a735a

File tree

4 files changed

+179
-0
lines changed

4 files changed

+179
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class LinalgOp;
2626
} // namespace linalg
2727

2828
namespace tensor {
29+
class InsertSliceOp;
2930
class PackOp;
3031
class PadOp;
3132
class UnPackOp;

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,4 +2005,42 @@ def HoistRedundantTensorSubsetsOp :
20052005
}];
20062006
}
20072007

2008+
//===----------------------------------------------------------------------===//
2009+
// InsertSliceToCopyOp
2010+
//===----------------------------------------------------------------------===//
2011+
2012+
def InsertSliceToCopyOp :
2013+
Op<Transform_Dialect, "structured.insert_slice_to_copy",
2014+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2015+
TransformEachOpTrait, TransformOpInterface]> {
2016+
let description = [{
2017+
Targeted rewrite of a tensor.insert_slice to linalg.copy.
2018+
This is useful to materialize copies explicitly before bufferization and
2019+
transform them, avoiding the need to rediscover them after bufferization.
2020+
2021+
If the insert_slice source is already a linalg.copy, only return the source
2022+
op (i.e. do not create an additional linalg.copy op).
2023+
2024+
#### Return modes:
2025+
2026+
The operation always succeeds and returns a handle to the relevant
2027+
linalg.copy op.
2028+
}];
2029+
2030+
let arguments = (ins TransformHandleTypeInterface:$target);
2031+
let results = (outs TransformHandleTypeInterface:$transformed);
2032+
2033+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
2034+
2035+
let builders = [
2036+
OpBuilder<(ins "Value":$target)>,
2037+
];
2038+
let extraClassDeclaration = [{
2039+
::mlir::DiagnosedSilenceableFailure applyToOne(
2040+
::mlir::tensor::InsertSliceOp target,
2041+
::mlir::transform::ApplyToEachResultList &results,
2042+
::mlir::transform::TransformState &state);
2043+
}];
2044+
}
2045+
20082046
#endif // LINALG_TRANSFORM_OPS

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3232,6 +3232,36 @@ transform::HoistRedundantTensorSubsetsOp::applyToOne(
32323232
return DiagnosedSilenceableFailure::success();
32333233
}
32343234

3235+
//===----------------------------------------------------------------------===//
3236+
// InsertSliceToCopyOp
3237+
//===----------------------------------------------------------------------===//
3238+
3239+
/// Materializes the data movement of a tensor.insert_slice as an explicit
/// linalg.copy. The IR created in place of `target` is:
///
///   %slice = tensor.extract_slice %dest[offsets] [sizes] [strides]
///   %copy  = linalg.copy ins(%source) outs(%slice)
///   %res   = tensor.insert_slice %copy into %dest[offsets] [sizes] [strides]
///
/// If the source of `target` is already defined by a linalg.copy, no IR is
/// created and that existing op is returned instead. In both cases exactly one
/// op (the linalg.copy) is appended to `results` and success is returned.
DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
    tensor::InsertSliceOp target, transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  // Nothing to materialize: the inserted value already comes from a copy.
  if (auto copySource = target.getSource().getDefiningOp<linalg::CopyOp>()) {
    results.push_back(copySource);
    return DiagnosedSilenceableFailure::success();
  }

  TrackingListener listener(state, *this);
  IRRewriter rewriter(target->getContext(), &listener);
  rewriter.setInsertionPoint(target);

  // Build the extract_slice with the insert_slice *source* type as its result
  // type instead of relying on type inference. The inferred type is never
  // rank-reduced, so for a rank-reducing insert_slice it would not match the
  // rank of the source and the linalg.copy below would be ill-typed. For
  // non-rank-reducing inserts both types coincide.
  Value extracted = rewriter.create<tensor::ExtractSliceOp>(
      target.getLoc(), target.getSourceType(), target.getDest(),
      target.getMixedOffsets(), target.getMixedSizes(),
      target.getMixedStrides());

  // Copy the source into the extracted destination slice.
  Value copied = rewriter
                     .create<linalg::CopyOp>(target.getLoc(),
                                             target.getSource(), extracted)
                     .getResult(0);

  // Re-insert the copied slice at the same position; `target` is replaced
  // (and tracked through `listener`) by the new insert_slice.
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      target, copied, target.getDest(), target.getMixedOffsets(),
      target.getMixedSizes(), target.getMixedStrides());

  results.push_back(copied.getDefiningOp());
  return DiagnosedSilenceableFailure::success();
}
3264+
32353265
//===----------------------------------------------------------------------===//
32363266
// Transform op registration
32373267
//===----------------------------------------------------------------------===//
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func @insert_slice_to_copy
4+
// CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
5+
// CHECK-SAME: %[[O:.*]]: tensor<?x?xf32>,
6+
// CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
7+
// CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
8+
// CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
9+
// CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
10+
// CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
11+
// CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
12+
func.func @insert_slice_to_copy(
13+
%I : tensor<2x3xf32>, %O : tensor<?x?xf32>,
14+
%off0 : index, %off1 : index,
15+
%sz0 : index, %sz1 : index,
16+
%st0 : index, %st1 : index) -> tensor<?x?xf32> {
17+
18+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
19+
// CHECK-SAME: : tensor<?x?xf32> to tensor<2x3xf32>
20+
// CHECK: linalg.copy ins(%[[I]] : tensor<2x3xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<2x3xf32>) -> tensor<2x3xf32>
21+
// CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
22+
// CHECK-SAME: : tensor<2x3xf32> into tensor<?x?xf32>
23+
24+
%0 = tensor.insert_slice %I into %O[%off0, %off1] [2, 3] [%st0, %st1]
25+
: tensor<2x3xf32> into tensor<?x?xf32>
26+
return %0 : tensor<?x?xf32>
27+
}
28+
29+
transform.sequence failures(propagate) {
30+
^bb1(%arg1: !transform.any_op):
31+
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
32+
%1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
33+
transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
34+
}
35+
36+
// -----
37+
38+
// CHECK-LABEL: func @insert_slice_to_copy
39+
// CHECK-SAME: %[[I:[0-9a-zA-Z]+]]: tensor<?x?xf32>
40+
// CHECK-SAME: %[[O:[0-9a-zA-Z]+]]: tensor<?x?xf32>,
41+
// CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
42+
// CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
43+
// CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
44+
// CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
45+
// CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
46+
// CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
47+
func.func @insert_slice_to_copy(
48+
%I : tensor<?x?xf32>, %O : tensor<?x?xf32>,
49+
%off0 : index, %off1 : index,
50+
%sz0 : index, %sz1 : index,
51+
%st0 : index, %st1 : index) -> tensor<?x?xf32> {
52+
53+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [%[[SZ0]], %[[SZ1]]] [1, 1]
54+
// CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
55+
// CHECK: linalg.copy ins(%[[I]] : tensor<?x?xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<?x?xf32>) -> tensor<?x?xf32>
56+
// CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [%[[SZ0]], %[[SZ1]]] [1, 1]
57+
// CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
58+
59+
%0 = tensor.insert_slice %I into %O[%off0, %off1] [%sz0, %sz1] [1, 1]
60+
: tensor<?x?xf32> into tensor<?x?xf32>
61+
return %0 : tensor<?x?xf32>
62+
}
63+
64+
transform.sequence failures(propagate) {
65+
^bb1(%arg1: !transform.any_op):
66+
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
67+
%1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
68+
transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
69+
}
70+
71+
// -----
72+
// CHECK-LABEL: func @insert_slice_to_copy
73+
// CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
74+
// CHECK-SAME: %[[O:.*]]: tensor<?x?xf32>,
75+
// CHECK-SAME: %[[OFF0:[0-9a-zA-Z]+]]: index,
76+
// CHECK-SAME: %[[OFF1:[0-9a-zA-Z]+]]: index,
77+
// CHECK-SAME: %[[SZ0:[0-9a-zA-Z]+]]: index,
78+
// CHECK-SAME: %[[SZ1:[0-9a-zA-Z]+]]: index,
79+
// CHECK-SAME: %[[ST0:[0-9a-zA-Z]+]]: index,
80+
// CHECK-SAME: %[[ST1:[0-9a-zA-Z]+]]: index)
81+
func.func @insert_slice_to_copy(
82+
%I : tensor<2x3xf32>, %O : tensor<?x?xf32>,
83+
%off0 : index, %off1 : index,
84+
%sz0 : index, %sz1 : index,
85+
%st0 : index, %st1 : index) -> tensor<?x?xf32> {
86+
87+
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
88+
// CHECK-SAME: : tensor<?x?xf32> to tensor<2x3xf32>
89+
// CHECK: linalg.copy ins(%[[I]] : tensor<2x3xf32>) outs(%[[EXTRACTED_SLICE]] : tensor<2x3xf32>) -> tensor<2x3xf32>
90+
// CHECK-NOT: linalg.copy
91+
// CHECK: tensor.insert_slice %{{.*}} into %[[O]][%[[OFF0]], %[[OFF1]]] [2, 3] [%[[ST0]], %[[ST1]]]
92+
// CHECK-SAME: : tensor<2x3xf32> into tensor<?x?xf32>
93+
94+
%extracted_slice = tensor.extract_slice %O[%off0, %off1] [2, 3] [%st0, %st1]
95+
: tensor<?x?xf32> to tensor<2x3xf32>
96+
%0 = linalg.copy ins(%I : tensor<2x3xf32>) outs(%extracted_slice
97+
: tensor<2x3xf32>) -> tensor<2x3xf32>
98+
%inserted_slice = tensor.insert_slice %0 into %O[%off0, %off1] [2, 3] [%st0, %st1]
99+
: tensor<2x3xf32> into tensor<?x?xf32>
100+
101+
return %inserted_slice : tensor<?x?xf32>
102+
}
103+
104+
transform.sequence failures(propagate) {
105+
^bb1(%arg1: !transform.any_op):
106+
%0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
107+
%1 = transform.structured.insert_slice_to_copy %0 : (!transform.any_op) -> !transform.any_op
108+
transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
109+
}
110+

0 commit comments

Comments
 (0)