@@ -211,64 +211,6 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
211
211
212
212
namespace {
213
213
214
- // / A pattern for ops that implement `MaskableOpInterface` and that _might_ be
215
- // / masked (i.e. inside `vector.mask` Op region). In particular:
216
- // / 1. It matches `SourceOp` operation, Op.
217
- // / 2. If Op is masked, retrieves the mask and updates the insertion point to
218
- // / avoid inserting new ops into `vector.mask` Op region (which only allows
219
- // / one Op). If the Op is not masked, this step is a nop.
220
- // / 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
221
- // / required) in the matched `vector.mask` operation from step 2.
222
- // /
223
- // / It frees the patterns implementing this class from worrying about the
224
- // / logic to update the insertion point. However, those patterns are still
225
- // / responsible for providing an updated version of:
226
- // / * the source Op when mask _is not_ present,
227
- // / * the source Op *and* the mask Op when mask _is_ present.
228
- template <class SourceOp >
229
- struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
230
- using OpRewritePattern<SourceOp>::OpRewritePattern;
231
-
232
- private:
233
- LogicalResult matchAndRewrite (SourceOp sourceOp,
234
- PatternRewriter &rewriter) const final {
235
- auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation ());
236
- if (!maskableOp)
237
- return failure ();
238
-
239
- // Op to update
240
- Operation *rootOp = sourceOp;
241
-
242
- // If this Op is masked:
243
- // * update the insertion point to avoid inserting into the vector.mask
244
- // Op region,
245
- // * update the Op to rewrite so that it's the parent vector.mask Op
246
- OpBuilder::InsertionGuard guard (rewriter);
247
- MaskingOpInterface maskOp;
248
- if (maskableOp.isMasked ()) {
249
- maskOp = maskableOp.getMaskingOp ();
250
- rewriter.setInsertionPoint (maskOp);
251
- rootOp = maskOp;
252
- }
253
-
254
- FailureOr<Value> newOp =
255
- matchAndRewriteMaskableOp (sourceOp, maskOp, rewriter);
256
- if (failed (newOp))
257
- return failure ();
258
-
259
- rewriter.replaceOp (rootOp, *newOp);
260
- return success ();
261
- }
262
-
263
- public:
264
- // Matches SourceOp that can potentially be masked with `maskingOp`. If the
265
- // latter is present, returns an updated masking op (with a replacement for
266
- // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
267
- virtual FailureOr<Value>
268
- matchAndRewriteMaskableOp (SourceOp sourceOp, MaskingOpInterface maskingOp,
269
- PatternRewriter &rewriter) const = 0 ;
270
- };
271
-
272
214
// / Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
273
215
// / semantics to:
274
216
// / ```
@@ -283,7 +225,7 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
283
225
// / This only kicks in when VectorTransformsOptions is set to OuterProduct and
284
226
// / the vector.contract op is a row-major matrix multiply.
285
227
class ContractionOpToMatmulOpLowering
286
- : public MaskableOpRewritePattern<vector::ContractionOp> {
228
+ : public vector:: MaskableOpRewritePattern<vector::ContractionOp> {
287
229
public:
288
230
using MaskableOpRewritePattern::MaskableOpRewritePattern;
289
231
0 commit comments