Skip to content

Commit 760ffa4

Browse files
authored
[mlir][tensor] Apply InsertSliceOfTransferWriteOpFolder only when transfer_write overwrites all elements of insert_slice (#108803)
Resolves #101708 The updated logic now correctly checks if `transfer_write` completely overwrites `insert_slice` and only then applies the rewrite for this pattern. This check currently covers static sizes, for dynamic sizes value bounds analysis is needed (see `TODO:`).
1 parent 38a8000 commit 760ffa4

File tree

2 files changed

+44
-16
lines changed

2 files changed

+44
-16
lines changed

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2222
#include "mlir/IR/AffineMap.h"
2323
#include "mlir/IR/BuiltinAttributes.h"
24+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2425
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2526
#include "llvm/ADT/TypeSwitch.h"
2627
#include <type_traits>
@@ -67,6 +68,10 @@ class InsertSliceOfTransferWriteOpFolder final
6768

6869
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
6970
PatternRewriter &rewriter) const override;
71+
72+
private:
73+
static bool
74+
doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
7075
};
7176
} // namespace
7277

@@ -136,6 +141,10 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
136141
if (failed(preconditionResult))
137142
return preconditionResult;
138143

144+
if (!doesTransferWriteCoverInsertSlice(writeOp))
145+
return rewriter.notifyMatchFailure(
146+
insertSliceOp, "transfer_write does not cover insert_slice");
147+
139148
SmallVector<Value> indices(writeOp.getIndices().begin(),
140149
writeOp.getIndices().end());
141150
SmallVector<Value> sourceIndices;
@@ -154,6 +163,17 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
154163
return success();
155164
}
156165

166+
bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
167+
vector::TransferWriteOp writeOp) {
168+
if (writeOp.getShapedType().hasStaticShape())
169+
return llvm::equal(writeOp.getVectorType().getShape(),
170+
writeOp.getShapedType().getShape());
171+
172+
// TODO: Use ValueBoundsConstraintSet for dynamic shapes.
173+
174+
return false;
175+
}
176+
157177
template <typename OpTy>
158178
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
159179
using OpRewritePattern<OpTy>::OpRewritePattern;

mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,6 @@ func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor<?x?x?
144144

145145
// -----
146146

147-
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
148-
149147
// CHECK: func @fold_vector_transfer_write_with_rank_reduced_insert_slice
150148
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
151149
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
@@ -155,18 +153,16 @@ func.func @transfer_read_of_extract_slice_swappy_rank_reducing(%t : tensor<?x?x?
155153
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
156154
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
157155
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
156+
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: tensor<?x?xf32>
158157
func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
159158
%arg0 : tensor<?x?x?xf32>,
160159
%arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
161160
%arg5: index, %arg6 : index, %arg7 : index,
162161
%st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
163162
%cst = arith.constant 0.0 : f32
164163

165-
// CHECK-NOT: insert_slice
166-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
167-
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
168-
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
169-
// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?x?xf32
164+
// CHECK-DAG: %[[r1:.*]] = vector.transfer_write %[[ARG1]], %[[ARG8]][%[[ARG6]], %[[ARG7]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?xf32>
165+
// CHECK-DAG: %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[ARG0]][0, %[[ARG2]], %[[ARG3]]] [1, %[[ARG4]], %[[ARG5]]] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
170166
%0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
171167
: vector<4xf32>, tensor<?x?xf32>
172168
%1 = tensor.insert_slice %0 into %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1]
@@ -176,9 +172,6 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
176172

177173
// -----
178174

179-
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
180-
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)>
181-
182175
// CHECK: func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice
183176
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
184177
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32>
@@ -188,19 +181,16 @@ func.func @fold_vector_transfer_write_with_rank_reduced_insert_slice(
188181
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
189182
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
190183
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
184+
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: tensor<?x?xf32>
191185
func.func @fold_vector_transfer_write_with_inner_rank_reduced_insert_slice(
192186
%arg0 : tensor<?x?x?xf32>,
193187
%arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index,
194188
%arg5: index, %arg6 : index, %arg7 : index,
195189
%st : tensor<?x?xf32>) -> tensor<?x?x?xf32> {
196190
%cst = arith.constant 0.0 : f32
197191

198-
// CHECK-NOT: insert_slice
199-
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
200-
// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG6]]]
201-
// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG7]]]
202-
// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]]
203-
// CHECK-SAME: {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, tensor<?x?x?xf32
192+
// CHECK-DAG: %[[r1:.*]] = vector.transfer_write %[[ARG1]], %[[ARG8]][%[[ARG6]], %[[ARG7]]] {in_bounds = [true]} : vector<4xf32>, tensor<?x?xf32>
193+
// CHECK-DAG: %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[ARG0]][%[[ARG2]], %[[ARG3]], 0] [%[[ARG4]], %[[ARG5]], 1] [1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?xf32>
204194
%0 = vector.transfer_write %arg1, %st[%arg6, %arg7] {in_bounds = [true]}
205195
: vector<4xf32>, tensor<?x?xf32>
206196
%1 = tensor.insert_slice %0 into %arg0[%arg2, %arg3, 0] [%arg4, %arg5, 1] [1, 1, 1]
@@ -226,6 +216,24 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
226216

227217
// -----
228218

219+
// This test is negative since `transfer_write` only
220+
// writes to `5x6` of the `100x100` elements of `%arg3`
221+
// CHECK-LABEL: func @insert_slice_of_transfer_write_overwrite_all(
222+
// CHECK-SAME: %[[arg0:.*]]: tensor<1000x1000xf32>, %[[arg1:.*]]: vector<5x6xf32>, %[[arg2:.*]]: index, %[[arg3:.*]]: tensor<100x100xf32>
223+
func.func @insert_slice_of_transfer_write_overwrite_all(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> {
224+
%c0 = arith.constant 0 : index
225+
226+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
227+
// CHECK: %[[r1:.*]] = vector.transfer_write %[[arg1]], %[[arg3]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
228+
// CHECK: %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[arg0]][3, %[[arg2]]] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
229+
// CHECK: return %[[r2]] : tensor<1000x1000xf32>
230+
%0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
231+
%inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
232+
return %inserted_slice : tensor<1000x1000xf32>
233+
}
234+
235+
// -----
236+
229237
// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
230238

231239
// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(

0 commit comments

Comments
 (0)