Skip to content

Commit 6b4aec5

Browse files
committed
[mlir][tensor] Apply InsertSliceOfTransferWriteOpFolder only when transfer_write overwrites
all elements of `insert_slice` Resolves #101708
1 parent b7e51b4 commit 6b4aec5

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2222
#include "mlir/IR/AffineMap.h"
2323
#include "mlir/IR/BuiltinAttributes.h"
24+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2425
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2526
#include "llvm/ADT/TypeSwitch.h"
2627
#include <type_traits>
@@ -67,6 +68,12 @@ class InsertSliceOfTransferWriteOpFolder final
6768

6869
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
6970
PatternRewriter &rewriter) const override;
71+
72+
private:
73+
static bool
74+
doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp,
75+
tensor::InsertSliceOp insertSliceOp,
76+
MLIRContext *context);
7077
};
7178
} // namespace
7279

@@ -136,6 +143,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
136143
if (failed(preconditionResult))
137144
return preconditionResult;
138145

146+
if (!doesTransferWriteCoverInsertSlice(writeOp, insertSliceOp,
147+
rewriter.getContext()))
148+
return rewriter.notifyMatchFailure(
149+
insertSliceOp, "transfer_write does not cover insert_slice");
150+
139151
SmallVector<Value> indices(writeOp.getIndices().begin(),
140152
writeOp.getIndices().end());
141153
SmallVector<Value> sourceIndices;
@@ -154,6 +166,25 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
154166
return success();
155167
}
156168

169+
bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
170+
vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
171+
MLIRContext *context) {
172+
auto destType = cast<ShapedType>(writeOp.getOperand(0).getType());
173+
auto insertSliceType = insertSliceOp.getSourceType();
174+
175+
if (destType.hasStaticShape() && insertSliceType.hasStaticShape()) {
176+
for (int64_t d = 0, e = insertSliceType.getRank(); d < e; ++d) {
177+
if (destType.getDimSize(d) != insertSliceType.getDimSize(d))
178+
return false;
179+
}
180+
return true;
181+
}
182+
183+
// Todo: ValueBoundsConstraintSet for dynamic shapes.
184+
185+
return true;
186+
}
187+
157188
template <typename OpTy>
158189
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
159190
using OpRewritePattern<OpTy>::OpRewritePattern;

mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,22 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
226226

227227
// -----
228228

229+
// CHECK-LABEL: func @insert_slice_of_transfer_write_overwrite_all(
230+
// CHECK-SAME: %[[arg0:.*]]: tensor<1000x1000xf32>, %[[arg1:.*]]: vector<5x6xf32>, %[[arg2:.*]]: index, %[[arg3:.*]]: tensor<100x100xf32>
231+
func.func @insert_slice_of_transfer_write_overwrite_all(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> {
232+
%c0 = arith.constant 0 : index
233+
234+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
235+
// CHECK: %[[r1:.*]] = vector.transfer_write %[[arg1]], %[[arg3]][%[[c0]], %[[c0]]] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
236+
// CHECK: %[[r2:.*]] = tensor.insert_slice %[[r1]] into %[[arg0]][3, %[[arg2]]] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
237+
// CHECK: return %[[r2]] : tensor<1000x1000xf32>
238+
%0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
239+
%inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
240+
return %inserted_slice : tensor<1000x1000xf32>
241+
}
242+
243+
// -----
244+
229245
// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
230246

231247
// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(

0 commit comments

Comments
 (0)