23
23
#include " mlir/IR/BuiltinAttributes.h"
24
24
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
25
25
#include " llvm/ADT/TypeSwitch.h"
26
+ #include < cstddef>
27
+ #include < sys/_types/_int64_t.h>
26
28
#include < type_traits>
27
29
28
30
namespace mlir {
@@ -67,6 +69,12 @@ class InsertSliceOfTransferWriteOpFolder final
67
69
68
70
LogicalResult matchAndRewrite (tensor::InsertSliceOp insertSliceOp,
69
71
PatternRewriter &rewriter) const override ;
72
+
73
+ private:
74
+ static bool
75
+ doesTransferWriteCoverInsertSlice (vector::TransferWriteOp writeOp,
76
+ tensor::InsertSliceOp insertSliceOp,
77
+ MLIRContext *context);
70
78
};
71
79
} // namespace
72
80
@@ -136,6 +144,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
136
144
if (failed (preconditionResult))
137
145
return preconditionResult;
138
146
147
+ if (!doesTransferWriteCoverInsertSlice (writeOp, insertSliceOp,
148
+ rewriter.getContext ()))
149
+ return rewriter.notifyMatchFailure (
150
+ insertSliceOp, " transfer_write does not cover insert_slice" );
151
+
139
152
SmallVector<Value> indices (writeOp.getIndices ().begin (),
140
153
writeOp.getIndices ().end ());
141
154
SmallVector<Value> sourceIndices;
@@ -154,6 +167,13 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
154
167
return success ();
155
168
}
156
169
170
+ bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice (
171
+ vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
172
+ MLIRContext *context) {
173
+ // Todo
174
+ return true ;
175
+ }
176
+
157
177
template <typename OpTy>
158
178
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern <OpTy> {
159
179
using OpRewritePattern<OpTy>::OpRewritePattern;
0 commit comments