Skip to content

[mlir][vector] Adds pattern rewrite for maskable Ops #83827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 20, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Mar 4, 2024

Adds a generic pattern rewrite for maskable Ops, MaskableOpRewritePattern,
that will work for both masked and un-masked cases, e.g. for both:

  • vector.mask {vector.contract} (masked), and
  • vector.contract (not masked).

This helps to reduce code-duplication and standardise how we implement such
patterns.

Fixes #78787

@llvmbot
Copy link
Member

llvmbot commented Mar 4, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Adds a generic pattern rewrite for maskable Ops that we would like to
work for both masked and un-masked cases, e.g. for both vector.mask {vector.contract} and vector.contract (this is a very contrived
example - just to demonstrate the idea).

This helps to reduce code-duplication and standardise how we implement
such patterns.

Fixes #78787


Full diff: https://github.com/llvm/llvm-project/pull/83827.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+143-108)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 0eaf9f71a37d21..6480b295c49cb6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -41,7 +41,6 @@ using namespace mlir::vector;
 //===----------------------------------------------------------------------===//
 // Helper functions
 //===----------------------------------------------------------------------===//
-
 // Helper to find an index in an affine map.
 static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
@@ -212,6 +211,64 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
 
 namespace {
 
+/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
+/// masked (i.e. inside `vector.mask` Op region). In particular:
+///   1. It matches `SourceOp` operation, Op.
+///   2. If Op is masked, retrieves the mask and updates the insertion point to
+///   avoid inserting new ops into `vector.mask` Op region (which only allows
+///   one Op). If the Op is not masked, this step is a nop.
+///   3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
+///   required) in the matched `vector.mask` operation from step 2.
+///
+/// It frees the patterns implementing this class from worrying about the
+/// logic to update the insertion point. However, those patterns are still
+/// responsible for providing an updated version of:
+///   * the source Op when mask _is not_ present,
+///   * the source Op *and* the mask Op when mask _is_ present.
+template <class SourceOp>
+struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+
+private:
+  LogicalResult matchAndRewrite(SourceOp sourceOp,
+                                PatternRewriter &rewriter) const final {
+    auto maskableOp =
+        dyn_cast_if_present<MaskableOpInterface>(sourceOp.getOperation());
+    if (!maskableOp)
+      return failure();
+
+    // Retrieve the mask if present
+    MaskingOpInterface maskOp;
+    if (maskableOp.isMasked())
+      maskOp = maskableOp.getMaskingOp();
+
+    // If this Op is masked, update the insertion point to avoid inserting into
+    // the vector.mask Op region.
+    OpBuilder::InsertionGuard guard(rewriter);
+    Operation *rootOp = sourceOp;
+    if (maskOp) {
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp;
+    }
+
+    FailureOr<Value> newOp =
+        matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
+    if (failed(newOp))
+      return failure();
+
+    rewriter.replaceOp(rootOp, *newOp);
+    return success();
+  }
+
+public:
+  // Matches SourceOp that can potentially be masked with `maskingOp`. If the
+  // latter is present, returns an updated masking op (with a replacement for
+  // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
+  virtual FailureOr<Value>
+  matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const = 0;
+};
+
 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to:
 /// ```
@@ -226,9 +283,9 @@ namespace {
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToMatmulOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -241,12 +298,13 @@ class ContractionOpToMatmulOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -270,9 +328,9 @@ class ContractionOpToMatmulOpLowering
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct and
 /// the vector.contract op is a row-major matrix multiply.
 class ContractionOpToOuterProductOpLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -285,12 +343,13 @@ class ContractionOpToOuterProductOpLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -317,9 +376,9 @@ class ContractionOpToOuterProductOpLowering
 /// This only kicks in when VectorTransformsOptions is set to Dot and
 /// the vector.contract op is a row-major matmul or matvec.
 class ContractionOpToDotLowering
-    : public OpRewritePattern<vector::ContractionOp> {
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
 
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
@@ -332,11 +391,12 @@ class ContractionOpToDotLowering
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -344,23 +404,10 @@ class ContractionOpToDotLowering
   FilterConstraintType filter;
 };
 
-/// Progressive lowering of ContractionOp.
-///
-/// One:
-///   %x = vector.contract with at least one free/batch dimension
-/// is replaced by:
-///   %a = vector.contract with one less free/batch dimension
-///   %b = vector.contract with one less free/batch dimension
-///   ..
-///   %x = combine %a %b ..
-/// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product.
-///
-/// This only kicks in when either VectorTransformsOptions is set
-/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
+class ContractionOpLowering
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
 
@@ -371,12 +418,13 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
                         MLIRContext *context, PatternBenefit benefit = 1,
                         FilterConstraintType constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions),
         filter(std::move(constraint)) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override;
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override;
 
 private:
   /// Options to control the vector patterns.
@@ -634,8 +682,10 @@ struct UnrolledOuterProductGenerator
 ///
 /// This only kicks in when VectorTransformsOptions is set to OuterProduct but
 /// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
+FailureOr<Value>
+ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (vectorTransformOptions.vectorContractLowering !=
       vector::VectorContractLowering::OuterProduct)
     return failure();
@@ -643,43 +693,25 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();
 
-  // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
-  Operation *rootOp;
-  if (maskableOp.isMasked()) {
-    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
-    rootOp = maskableOp.getMaskingOp();
-  } else {
-    rootOp = op;
-  }
-
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(rootOp, *matmatRes);
-    return success();
+    return matmatRes;
   }
   FailureOr<Value> matvecRes = e.matvec();
   if (succeeded(matvecRes)) {
-    rewriter.replaceOp(rootOp, *matvecRes);
-    return success();
-  }
-  FailureOr<Value> tmatvecRes = e.tmatvec();
-  if (succeeded(tmatvecRes)) {
-    rewriter.replaceOp(rootOp, *tmatvecRes);
-    return success();
+    return matvecRes;
   }
 
-  return failure();
+  FailureOr<Value> tmatvecRes = e.tmatvec();
+  return tmatvecRes;
 }
 
-LogicalResult
-ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
-                                            PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
+  if (maskOp)
     return failure();
 
   if (failed(filter(op)))
@@ -788,15 +820,14 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
   }
   if (auto acc = op.getAcc())
     res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
-  rewriter.replaceOp(op, res);
-  return success();
+  return res;
 }
 
 /// Lower vector.contract with all size one reduction dimensions to
 /// elementwise ops when possible.
 struct ContractOpToElementwise
-    : public OpRewritePattern<vector::ContractionOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public MaybeMaskedOpRewritePattern<vector::ContractionOp> {
+  using MaybeMaskedOpRewritePattern::MaybeMaskedOpRewritePattern;
   using FilterConstraintType =
       std::function<LogicalResult(vector::ContractionOp op)>;
   static LogicalResult defaultFilter(vector::ContractionOp op) {
@@ -806,14 +837,15 @@ struct ContractOpToElementwise
       vector::VectorTransformsOptions vectorTransformOptions,
       MLIRContext *context, PatternBenefit benefit = 1,
       const FilterConstraintType &constraint = defaultFilter)
-      : OpRewritePattern<vector::ContractionOp>(context, benefit),
+      : MaybeMaskedOpRewritePattern<vector::ContractionOp>(context, benefit),
         vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
 
-  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
+                            MaskingOpInterface maskOp,
+                            PatternRewriter &rewriter) const override {
     // TODO: Support vector.mask.
-    auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
-    if (maskableOp.isMasked())
+    if (maskOp)
       return failure();
 
     if (failed(filter(contractOp)))
@@ -903,8 +935,10 @@ struct ContractOpToElementwise
     std::optional<Value> result =
         createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                               contractOp.getKind(), rewriter, isInt);
-    rewriter.replaceOp(contractOp, {*result});
-    return success();
+    if (result)
+      return *result;
+
+    return failure();
   }
 
 private:
@@ -930,9 +964,9 @@ struct ContractOpToElementwise
 // TODO: break down into transpose/reshape/cast ops
 //               when they become available to avoid code dup
 // TODO: investigate lowering order impact on performance
-LogicalResult
-ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                       PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rewriter) const {
   if (failed(filter(op)))
     return failure();
 
@@ -951,29 +985,36 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
 
   // TODO: implement benefits, cost models.
   MLIRContext *ctx = op.getContext();
+
   ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
-  if (succeeded(pat1.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal1 =
+      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal1))
+    return newVal1;
+
   ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
-  if (succeeded(pat2.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal2 =
+      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal2))
+    return newVal2;
+
   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
-  if (succeeded(pat3.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal3 =
+      pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal3))
+    return newVal3;
+
   ContractOpToElementwise pat4(vectorTransformOptions, ctx);
-  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
-    return success();
+  FailureOr<Value> newVal4 =
+      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
+  if (!failed(newVal4))
+    return newVal4;
 
   // Vector mask setup.
-  OpBuilder::InsertionGuard guard(rewriter);
-  Operation *rootOp = op;
-  Value mask;
-  if (op.isMasked()) {
-    rewriter.setInsertionPoint(op.getMaskingOp());
-    rootOp = op.getMaskingOp();
-    mask = op.getMaskingOp().getMask();
-  }
 
+  Value mask;
+  if (maskOp)
+    mask = maskOp.getMask();
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
   if (!batchDimMap.empty()) {
@@ -982,8 +1023,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
     auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
+    return newOp;
   }
 
   // Collect contracting dimensions.
@@ -1003,8 +1043,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
       auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
+      return newOp;
     }
   }
 
@@ -1015,8 +1054,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
       auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(rootOp, *newOp);
-      return success();
+      return newOp;
     }
   }
 
@@ -1025,8 +1063,7 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
     auto newOp = lowerReduction(rewriter, op, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(rootOp, *newOp);
-    return success();
+    return newOp;
   }
 
   return failure();
@@ -1291,12 +1328,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 /// This only kicks in when VectorTransformsOptions is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
 /// row-major matrix multiply.
-LogicalResult
-ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
-                                                 PatternRewriter &rew) const {
+FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
+    vector::ContractionOp op, MaskingOpInterface maskOp,
+    PatternRewriter &rew) const {
   // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
+  if (maskOp)
     return failure();
 
   if (vectorTransformOptions.vectorContractLowering !=
@@ -1379,8 +1415,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
           : static_cast<Value>(
                 rew.create<arith::AddFOp>(loc, op.getAcc(), mul));
 
-  rew.replaceOp(op, res);
-  return success();
+  return res;
 }
 } // namespace
 

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me, but perhaps we want some feedback from @dcaballe

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG! A few comments

/// * the source Op when mask _is not_ present,
/// * the source Op *and* the mask Op when mask _is_ present.
template <class SourceOp>
struct MaybeMaskedOpRewritePattern : OpRewritePattern<SourceOp> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main concern I've been struggling with for a while: this pattern matches the nested op and the rewrite would replace its parent. I'm not quite sure pattern rewrite allows that even though it works for this cases. We can make this match against the vector mask op but then we would need another pattern to match against the source op for the unmasked cases. I couldn't find an elegant way to provide this functionality but maybe you can :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main concern I've been struggling with for a while: this pattern matches the nested op and the rewrite would replace its parent. I'm not quite sure pattern rewrite allows that even though it works for this cases

I wasn't sure myself so I asked:

I guess that this should be fine then?

I couldn't find an elegant way to provide this functionality but maybe you can :)

No better ideas just yet :/

Comment on lines -661 to +655
return success();
return matmatRes;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we replace the root op here? That should preserve the same structure as for other patterns

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two cases though:

  • vector.contract --> here we should replace vector.contract
  • vector.mask {vector.contract} --> here we should replace vector.mask (no need to "replace" vector contract)

This can be implemented, but then you will require logic like this in every pattern:

  OpBuilder::InsertionGuard guard(rewriter);
  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
  Operation *rootOp;
  if (maskableOp.isMasked()) {
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
  } else {
    rootOp = op;
  }

That's something that I am removing and instead moving to the base class, MaskableOpRewritePattern. But it also means I can't "replace" anymore as the logic to decide "what" to replace has been removed.

Am I overthinking this? 🤔

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Adds a generic pattern rewrite for maskable Ops that we would like to
work for both masked and un-masked cases, e.g. for both `vector.mask
{vector.contract}` and `vector.contract` (this is a very contrived
example - just to demonstrate the idea).

This helps to reduce code-duplication and standardise how we implement
such patterns.

Fixes llvm#78787
Rename the pattern, simplify code, restore docs
@banach-space banach-space force-pushed the andrzej/refactor_masked_rewrites branch from 158e6dd to 3ae7afc Compare March 19, 2024 22:29
@banach-space banach-space merged commit b7324b6 into llvm:main Mar 20, 2024
@banach-space banach-space deleted the andrzej/refactor_masked_rewrites branch March 20, 2024 21:04
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 21, 2024
Updates `castAwayContractionLeadingOneDim` to inherit from
`MaskableOpRewritePattern` so that this pattern can support masking.

Builds on top of llvm#83827
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 22, 2024
Updates `castAwayContractionLeadingOneDim` to inherit from
`MaskableOpRewritePattern` so that this pattern can support masking.

Builds on top of llvm#83827
banach-space added a commit that referenced this pull request Mar 22, 2024
…Dim (#81906)

Updates `castAwayContractionLeadingOneDim` to inherit from
`MaskableOpRewritePattern` so that this pattern can support masking.

Builds on top of #83827
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
Adds a generic pattern rewrite for maskable Ops, `MaskableOpRewritePattern`,
that will work for both masked and un-masked cases, e.g. for both:

* `vector.mask {vector.contract}` (masked), and
* `vector.contract` (not masked).

This helps to reduce code-duplication and standardise how we implement such
patterns.

Fixes llvm#78787
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…Dim (llvm#81906)

Updates `castAwayContractionLeadingOneDim` to inherit from
`MaskableOpRewritePattern` so that this pattern can support masking.

Builds on top of llvm#83827
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] Special handling for pattern rewrite on maskable op
4 participants