21
21
#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
22
22
#include " mlir/IR/AffineMap.h"
23
23
#include " mlir/IR/BuiltinAttributes.h"
24
+ #include " mlir/Interfaces/ValueBoundsOpInterface.h"
24
25
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26
#include " llvm/ADT/TypeSwitch.h"
26
27
#include < type_traits>
@@ -67,6 +68,12 @@ class InsertSliceOfTransferWriteOpFolder final
67
68
68
69
LogicalResult matchAndRewrite (tensor::InsertSliceOp insertSliceOp,
69
70
PatternRewriter &rewriter) const override ;
71
+
72
+ private:
73
+ static bool
74
+ doesTransferWriteCoverInsertSlice (vector::TransferWriteOp writeOp,
75
+ tensor::InsertSliceOp insertSliceOp,
76
+ MLIRContext *context);
70
77
};
71
78
} // namespace
72
79
@@ -136,6 +143,11 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
136
143
if (failed (preconditionResult))
137
144
return preconditionResult;
138
145
146
+ if (!doesTransferWriteCoverInsertSlice (writeOp, insertSliceOp,
147
+ rewriter.getContext ()))
148
+ return rewriter.notifyMatchFailure (
149
+ insertSliceOp, " transfer_write does not cover insert_slice" );
150
+
139
151
SmallVector<Value> indices (writeOp.getIndices ().begin (),
140
152
writeOp.getIndices ().end ());
141
153
SmallVector<Value> sourceIndices;
@@ -154,6 +166,25 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
154
166
return success ();
155
167
}
156
168
169
+ bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice (
170
+ vector::TransferWriteOp writeOp, tensor::InsertSliceOp insertSliceOp,
171
+ MLIRContext *context) {
172
+ auto destType = cast<ShapedType>(writeOp.getOperand (0 ).getType ());
173
+ auto insertSliceType = insertSliceOp.getSourceType ();
174
+
175
+ if (destType.hasStaticShape () && insertSliceType.hasStaticShape ()) {
176
+ for (int64_t d = 0 , e = insertSliceType.getRank (); d < e; ++d) {
177
+ if (destType.getDimSize (d) != insertSliceType.getDimSize (d))
178
+ return false ;
179
+ }
180
+ return true ;
181
+ }
182
+
183
+ // Todo: ValueBoundsConstraintSet for dynamic shapes.
184
+
185
+ return true ;
186
+ }
187
+
157
188
template <typename OpTy>
158
189
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern <OpTy> {
159
190
using OpRewritePattern<OpTy>::OpRewritePattern;
0 commit comments