Skip to content

Commit 69814ed

Browse files
committed
fixup! [mlir][vector] Adds pattern rewrite for maskable Ops
Rename the pattern, simplify code, restore docs
1 parent 2f91610 commit 69814ed

File tree

1 file changed

+40
-26
lines changed

1 file changed

+40
-26
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -226,27 +226,27 @@ namespace {
226226
/// * the source Op when mask _is not_ present,
227227
/// * the source Op *and* the mask Op when mask _is_ present.
228228
template <class SourceOp>
229-
struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
229+
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
230230
using OpRewritePattern<SourceOp>::OpRewritePattern;
231231

232232
private:
233233
LogicalResult matchAndRewrite(SourceOp sourceOp,
234234
PatternRewriter &rewriter) const final {
235-
auto maskableOp =
236-
dyn_cast_if_present<MaskableOpInterface>(sourceOp.getOperation());
235+
auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
237236
if (!maskableOp)
238237
return failure();
239238

240-
// Retrieve the mask if present
241-
MaskingOpInterface maskOp;
242-
if (maskableOp.isMasked())
243-
maskOp = maskableOp.getMaskingOp();
239+
// Op to update
240+
Operation *rootOp = sourceOp;
244241

245-
// If this Op is masked, update the insertion point to avoid inserting into
246-
// the vector.mask Op region.
242+
// If this Op is masked:
243+
// * update the insertion point to avoid inserting into the vector.mask
244+
// Op region,
245+
// * update the Op to rewrite so that it's the parent vector.mask Op
247246
OpBuilder::InsertionGuard guard(rewriter);
248-
Operation *rootOp = sourceOp;
249-
if (maskOp) {
247+
MaskingOpInterface maskOp;
248+
if (maskableOp.isMasked()) {
249+
maskOp = maskableOp.getMaskingOp();
250250
rewriter.setInsertionPoint(maskOp);
251251
rootOp = maskOp;
252252
}
@@ -283,9 +283,9 @@ struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
283283
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
284284
/// the vector.contract op is a row-major matrix multiply.
285285
class ContractionOpToMatmulOpLowering
286-
: public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
286+
: public MaskableOpRewritePattern<vector::ContractionOp> {
287287
public:
288-
using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
288+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
289289

290290
using FilterConstraintType =
291291
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -298,7 +298,7 @@ class ContractionOpToMatmulOpLowering
298298
vector::VectorTransformsOptions vectorTransformOptions,
299299
MLIRContext *context, PatternBenefit benefit = 1,
300300
FilterConstraintType constraint = defaultFilter)
301-
: MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
301+
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
302302
vectorTransformOptions(vectorTransformOptions),
303303
filter(std::move(constraint)) {}
304304

@@ -328,9 +328,9 @@ class ContractionOpToMatmulOpLowering
328328
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
329329
/// the vector.contract op is a row-major matrix multiply.
330330
class ContractionOpToOuterProductOpLowering
331-
: public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
331+
: public MaskableOpRewritePattern<vector::ContractionOp> {
332332
public:
333-
using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
333+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
334334

335335
using FilterConstraintType =
336336
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -343,7 +343,7 @@ class ContractionOpToOuterProductOpLowering
343343
vector::VectorTransformsOptions vectorTransformOptions,
344344
MLIRContext *context, PatternBenefit benefit = 1,
345345
FilterConstraintType constraint = defaultFilter)
346-
: MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
346+
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
347347
vectorTransformOptions(vectorTransformOptions),
348348
filter(std::move(constraint)) {}
349349

@@ -376,9 +376,9 @@ class ContractionOpToOuterProductOpLowering
376376
/// This only kicks in when VectorTransformsOptions is set to Dot and
377377
/// the vector.contract op is a row-major matmul or matvec.
378378
class ContractionOpToDotLowering
379-
: public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
379+
: public MaskableOpRewritePattern<vector::ContractionOp> {
380380
public:
381-
using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
381+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
382382

383383
using FilterConstraintType =
384384
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -391,7 +391,7 @@ class ContractionOpToDotLowering
391391
vector::VectorTransformsOptions vectorTransformOptions,
392392
MLIRContext *context, PatternBenefit benefit = 1,
393393
const FilterConstraintType &constraint = defaultFilter)
394-
: MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
394+
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
395395
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
396396

397397
FailureOr<Value>
@@ -404,10 +404,24 @@ class ContractionOpToDotLowering
404404
FilterConstraintType filter;
405405
};
406406

407+
/// Progressive lowering of ContractionOp.
408+
///
409+
/// One:
410+
/// %x = vector.contract with at least one free/batch dimension
411+
/// is replaced by:
412+
/// %a = vector.contract with one less free/batch dimension
413+
/// %b = vector.contract with one less free/batch dimension
414+
/// ..
415+
/// %x = combine %a %b ..
416+
/// until a pure contraction is reached (no free/batch dimensions),
417+
/// which is replaced by a dot-product.
418+
///
419+
/// This only kicks in when either VectorTransformsOptions is set
420+
/// to Dot or when other contraction patterns fail.
407421
class ContractionOpLowering
408-
: public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
422+
: public MaskableOpRewritePattern<vector::ContractionOp> {
409423
public:
410-
using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
424+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
411425
using FilterConstraintType =
412426
std::function<LogicalResult(vector::ContractionOp op)>;
413427

@@ -418,7 +432,7 @@ class ContractionOpLowering
418432
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
419433
MLIRContext *context, PatternBenefit benefit = 1,
420434
FilterConstraintType constraint = defaultFilter)
421-
: MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
435+
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
422436
vectorTransformOptions(vectorTransformOptions),
423437
filter(std::move(constraint)) {}
424438

@@ -826,8 +840,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
826840
/// Lower vector.contract with all size one reduction dimensions to
827841
/// elementwise ops when possible.
828842
struct ContractOpToElementwise
829-
: public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
830-
using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
843+
: public MaskableOpRewritePattern<vector::ContractionOp> {
844+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
831845
using FilterConstraintType =
832846
std::function<LogicalResult(vector::ContractionOp op)>;
833847
static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -837,7 +851,7 @@ struct ContractOpToElementwise
837851
vector::VectorTransformsOptions vectorTransformOptions,
838852
MLIRContext *context, PatternBenefit benefit = 1,
839853
const FilterConstraintType &constraint = defaultFilter)
840-
: MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
854+
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
841855
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
842856

843857
FailureOr<Value>

0 commit comments

Comments
 (0)