@@ -226,27 +226,27 @@ namespace {
226
226
// / * the source Op when mask _is not_ present,
227
227
// / * the source Op *and* the mask Op when mask _is_ present.
228
228
template <class SourceOp >
229
- struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
229
+ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
230
230
using OpRewritePattern<SourceOp>::OpRewritePattern;
231
231
232
232
private:
233
233
LogicalResult matchAndRewrite (SourceOp sourceOp,
234
234
PatternRewriter &rewriter) const final {
235
- auto maskableOp =
236
- dyn_cast_if_present<MaskableOpInterface>(sourceOp.getOperation ());
235
+ auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation ());
237
236
if (!maskableOp)
238
237
return failure ();
239
238
240
- // Retrieve the mask if present
241
- MaskingOpInterface maskOp;
242
- if (maskableOp.isMasked ())
243
- maskOp = maskableOp.getMaskingOp ();
239
+ // Op to update
240
+ Operation *rootOp = sourceOp;
244
241
245
- // If this Op is masked, update the insertion point to avoid inserting into
246
- // the vector.mask Op region.
242
+ // If this Op is masked:
243
+ // * update the insertion point to avoid inserting into the vector.mask
244
+ // Op region,
245
+ // * update the Op to rewrite so that it's the parent vector.mask Op
247
246
OpBuilder::InsertionGuard guard (rewriter);
248
- Operation *rootOp = sourceOp;
249
- if (maskOp) {
247
+ MaskingOpInterface maskOp;
248
+ if (maskableOp.isMasked ()) {
249
+ maskOp = maskableOp.getMaskingOp ();
250
250
rewriter.setInsertionPoint (maskOp);
251
251
rootOp = maskOp;
252
252
}
@@ -283,9 +283,9 @@ struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
283
283
// / This only kicks in when VectorTransformsOptions is set to OuterProduct and
284
284
// / the vector.contract op is a row-major matrix multiply.
285
285
class ContractionOpToMatmulOpLowering
286
- : public MaybeMaskedOpRewritePattern <vector::ContractionOp> {
286
+ : public MaskableOpRewritePattern <vector::ContractionOp> {
287
287
public:
288
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern ;
288
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
289
289
290
290
using FilterConstraintType =
291
291
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -298,7 +298,7 @@ class ContractionOpToMatmulOpLowering
298
298
vector::VectorTransformsOptions vectorTransformOptions,
299
299
MLIRContext *context, PatternBenefit benefit = 1 ,
300
300
FilterConstraintType constraint = defaultFilter)
301
- : MaybeMaskedOpRewritePattern <vector::ContractionOp>(context, benefit),
301
+ : MaskableOpRewritePattern <vector::ContractionOp>(context, benefit),
302
302
vectorTransformOptions (vectorTransformOptions),
303
303
filter (std::move(constraint)) {}
304
304
@@ -328,9 +328,9 @@ class ContractionOpToMatmulOpLowering
328
328
// / This only kicks in when VectorTransformsOptions is set to OuterProduct and
329
329
// / the vector.contract op is a row-major matrix multiply.
330
330
class ContractionOpToOuterProductOpLowering
331
- : public MaybeMaskedOpRewritePattern <vector::ContractionOp> {
331
+ : public MaskableOpRewritePattern <vector::ContractionOp> {
332
332
public:
333
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern ;
333
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
334
334
335
335
using FilterConstraintType =
336
336
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -343,7 +343,7 @@ class ContractionOpToOuterProductOpLowering
343
343
vector::VectorTransformsOptions vectorTransformOptions,
344
344
MLIRContext *context, PatternBenefit benefit = 1 ,
345
345
FilterConstraintType constraint = defaultFilter)
346
- : MaybeMaskedOpRewritePattern <vector::ContractionOp>(context, benefit),
346
+ : MaskableOpRewritePattern <vector::ContractionOp>(context, benefit),
347
347
vectorTransformOptions (vectorTransformOptions),
348
348
filter (std::move(constraint)) {}
349
349
@@ -376,9 +376,9 @@ class ContractionOpToOuterProductOpLowering
376
376
// / This only kicks in when VectorTransformsOptions is set to Dot and
377
377
// / the vector.contract op is a row-major matmul or matvec.
378
378
class ContractionOpToDotLowering
379
- : public MaybeMaskedOpRewritePattern <vector::ContractionOp> {
379
+ : public MaskableOpRewritePattern <vector::ContractionOp> {
380
380
public:
381
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern ;
381
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
382
382
383
383
using FilterConstraintType =
384
384
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -391,7 +391,7 @@ class ContractionOpToDotLowering
391
391
vector::VectorTransformsOptions vectorTransformOptions,
392
392
MLIRContext *context, PatternBenefit benefit = 1 ,
393
393
const FilterConstraintType &constraint = defaultFilter)
394
- : MaybeMaskedOpRewritePattern <vector::ContractionOp>(context, benefit),
394
+ : MaskableOpRewritePattern <vector::ContractionOp>(context, benefit),
395
395
vectorTransformOptions (vectorTransformOptions), filter(defaultFilter) {}
396
396
397
397
FailureOr<Value>
@@ -404,10 +404,24 @@ class ContractionOpToDotLowering
404
404
FilterConstraintType filter;
405
405
};
406
406
407
+ // / Progressive lowering of ContractionOp.
408
+ // /
409
+ // / One:
410
+ // / %x = vector.contract with at least one free/batch dimension
411
+ // / is replaced by:
412
+ // / %a = vector.contract with one less free/batch dimension
413
+ // / %b = vector.contract with one less free/batch dimension
414
+ // / ..
415
+ // / %x = combine %a %b ..
416
+ // / until a pure contraction is reached (no free/batch dimensions),
417
+ // / which is replaced by a dot-product.
418
+ // /
419
+ // / This only kicks in when either VectorTransformsOptions is set
420
+ // / to Dot or when other contraction patterns fail.
407
421
class ContractionOpLowering
408
- : public MaybeMaskedOpRewritePattern <vector::ContractionOp> {
422
+ : public MaskableOpRewritePattern <vector::ContractionOp> {
409
423
public:
410
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern ;
424
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
411
425
using FilterConstraintType =
412
426
std::function<LogicalResult(vector::ContractionOp op)>;
413
427
@@ -418,7 +432,7 @@ class ContractionOpLowering
418
432
ContractionOpLowering (vector::VectorTransformsOptions vectorTransformOptions,
419
433
MLIRContext *context, PatternBenefit benefit = 1 ,
420
434
FilterConstraintType constraint = defaultFilter)
421
- : MaybeMaskedOpRewritePattern <vector::ContractionOp>(context, benefit),
435
+ : MaskableOpRewritePattern <vector::ContractionOp>(context, benefit),
422
436
vectorTransformOptions (vectorTransformOptions),
423
437
filter (std::move(constraint)) {}
424
438
@@ -826,8 +840,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
826
840
// / Lower vector.contract with all size one reduction dimensions to
827
841
// / elementwise ops when possible.
828
842
struct ContractOpToElementwise
829
- : public MaybeMaskedOpRewritePattern <vector::ContractionOp> {
830
- using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern ;
843
+ : public MaskableOpRewritePattern <vector::ContractionOp> {
844
+ using MaskableOpRewritePattern::MaskableOpRewritePattern ;
831
845
using FilterConstraintType =
832
846
std::function<LogicalResult(vector::ContractionOp op)>;
833
847
static LogicalResult defaultFilter (vector::ContractionOp op) {
@@ -837,7 +851,7 @@ struct ContractOpToElementwise
837
851
vector::VectorTransformsOptions vectorTransformOptions,
838
852
MLIRContext *context, PatternBenefit benefit = 1 ,
839
853
const FilterConstraintType &constraint = defaultFilter)
840
- : MaybeMaskedOpRewritePattern <vector::ContractionOp>(context, benefit),
854
+ : MaskableOpRewritePattern <vector::ContractionOp>(context, benefit),
841
855
vectorTransformOptions (vectorTransformOptions), filter(defaultFilter) {}
842
856
843
857
FailureOr<Value>
0 commit comments