Skip to content

Commit b5ba23d

Browse files
committed
[mlir][tensor] Apply InsertSliceOfTransferWriteOpFolder only when transfer_write overwrites
all elements of `insert_slice` Resolves #101708
1 parent 57c1e21 commit b5ba23d

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "mlir/IR/BuiltinAttributes.h"
2424
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2525
#include "llvm/ADT/TypeSwitch.h"
26+
#include <cstddef>
27+
#include <sys/_types/_int64_t.h>
2628
#include <type_traits>
2729

2830
namespace mlir {
@@ -67,6 +69,12 @@ class InsertSliceOfTransferWriteOpFolder final
6769

6870
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
6971
PatternRewriter &rewriter) const override;
72+
73+
private:
74+
static bool
75+
doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp,
76+
tensor::InsertSliceOp insertSliceOp,
77+
MLIRContext *context);
7078
};
7179
} // namespace
7280

@@ -136,6 +144,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
136144
if (failed(preconditionResult))
137145
return preconditionResult;
138146

147+
if (!doesTransferWriteCoverInsertSlice(writeOp, insertSliceOp,
148+
rewriter.getContext()))
149+
return rewriter.notifyMatchFailure(
150+
insertSliceOp, "transfer_write does not cover insert_slice");
151+
139152
SmallVector<Value> indices(writeOp.getIndices().begin(),
140153
writeOp.getIndices().end());
141154
SmallVector<Value> sourceIndices;
@@ -154,6 +167,13 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
154167
return success();
155168
}
156169

170+
bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
171+
vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
172+
MLIRContext *context) {
173+
// Todo
174+
return true;
175+
}
176+
157177
template <typename OpTy>
158178
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
159179
using OpRewritePattern<OpTy>::OpRewritePattern;

mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,15 @@ func.func @insert_slice_of_transfer_write(%t1 : tensor<?x12xf32>, %v : vector<5x
226226

227227
// -----
228228

229+
func.func @insert_slice_of_transfer_write(%arg0: tensor<1000x1000xf32>, %arg1: vector<5x6xf32>, %arg2: index, %arg3: tensor<100x100xf32>) -> tensor<1000x1000xf32> {
230+
%c0 = arith.constant 0 : index
231+
%0 = vector.transfer_write %arg1, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<100x100xf32>
232+
%inserted_slice = tensor.insert_slice %0 into %arg0[3, %arg2] [100, 100] [1, 1] : tensor<100x100xf32> into tensor<1000x1000xf32>
233+
return %inserted_slice : tensor<1000x1000xf32>
234+
}
235+
236+
// -----
237+
229238
// CHECK-DAG: #[[$d0d2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
230239

231240
// CHECK-LABEL: func @insert_slice_of_transfer_write_swappy_rank_extending(

0 commit comments

Comments
 (0)