[mlir][vector] Add support for masks in castAwayContractionLeadingOneDim #81906

Merged: 1 commit into llvm:main from andrzej/add_mask_support, Mar 22, 2024

Conversation

banach-space (Contributor) commented on Feb 15, 2024

Updates castAwayContractionLeadingOneDim to inherit from
MaskableOpRewritePattern so that this pattern can support masking.

Builds on top of #83827
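
For context, here is a minimal sketch of the pattern shape this description refers to. It assumes the `MaskableOpRewritePattern` base class from #83827, which implements `matchAndRewrite` once and forwards to a hook that also receives the (possibly null) masking op; the class name and body below are illustrative, not the exact merged code:

// Sketch only: assumes the MaskableOpRewritePattern base from #83827
// dispatches to matchAndRewriteMaskableOp, passing a null maskingOp for
// unmasked ops.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"

using namespace mlir;

struct CastAwayContractionLeadingOneDimSketch final
    : vector::MaskableOpRewritePattern<vector::ContractionOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  // Called for both a bare vector.contract and one wrapped in vector.mask;
  // maskingOp is null in the unmasked case, so one rewrite covers both.
  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    // The real pattern drops the leading unit dim from the operands (and
    // from the mask, when present), rebuilds the contraction, and
    // broadcasts the result back to the original shape.
    return failure(); // placeholder body in this sketch
  }
};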

llvmbot (Member) commented on Feb 15, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Partial fix for #78787


Full diff: https://github.com/llvm/llvm-project/pull/81906.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+34-11)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+74-30)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..f7f2b934056185 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -332,9 +332,12 @@ struct CastAwayTransferWriteLeadingOneDim
 LogicalResult
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                                RewriterBase &rewriter) {
-  // TODO(#78787): Not supported masked op yet.
-  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
-    return failure();
+  // Specifically for masked Ops for which we need to update the insertion
+  // point
+  PatternRewriter::InsertionGuard guard(rewriter);
+
+  auto isMasked =
+      cast<MaskableOpInterface>(contractOp.getOperation()).isMasked();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
   if (oldAccType == nullptr)
     return failure();
@@ -346,6 +349,12 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
   // greedily to drop more.
   int64_t dropDim = 1;
 
+  if (isMasked) {
+    // Update the insertion point to avoid adding more ops to the vector.mask
+    // region corresponding to `mask`
+    rewriter.setInsertionPointAfter(contractOp->getParentOp());
+  }
+
   auto oldIndexingMaps = contractOp.getIndexingMapsArray();
   SmallVector<AffineMap> newIndexingMaps;
 
@@ -368,6 +377,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
   SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
                                  contractOp.getAcc()};
   SmallVector<Value> newOperands;
+  auto loc = contractOp.getLoc();
 
   for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
     // Check if the dim to be dropped exists as a leading dim in the operand
@@ -405,7 +415,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
         map = AffineMap::get(map.getNumDims(), 0, transposeResults,
                              contractOp.getContext());
         operands[it.index()] = rewriter.create<vector::TransposeOp>(
-            contractOp.getLoc(), operands[it.index()], perm);
+            loc, operands[it.index()], perm);
       }
     }
     // We have taken care to have the dim to be dropped be
@@ -429,17 +439,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
     // Extract if its a valid extraction, otherwise use the operand
     // without extraction.
     newOperands.push_back(
-        validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
-                                                          operands[it.index()],
-                                                          splatZero(dropDim))
+        validExtract ? rewriter.create<vector::ExtractOp>(
+                           loc, operands[it.index()], splatZero(dropDim))
                      : operands[it.index()]);
   }
-  auto newContractOp = rewriter.create<vector::ContractionOp>(
-      contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+  Operation *newContractOp = rewriter.create<vector::ContractionOp>(
+      loc, newOperands[0], newOperands[1], newOperands[2],
       rewriter.getAffineMapArrayAttr(newIndexingMaps),
       rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
-  rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-      contractOp, contractOp->getResultTypes()[0], newContractOp);
+
+  if (isMasked) {
+    auto mask = contractOp.getMaskingOp();
+    auto newMask = rewriter.create<vector::ExtractOp>(loc, mask.getMask(),
+                                                      splatZero(dropDim));
+
+    newContractOp =
+        mlir::vector::maskOperation(rewriter, newContractOp, newMask);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        mask, contractOp->getResultTypes()[0], newContractOp->getResults()[0]);
+  } else {
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        contractOp, contractOp->getResultTypes()[0],
+        newContractOp->getResults()[0]);
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..4ba51c5953d13c 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
 }
 
 // -----
+// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_const_mask
+// CHECK:           %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} 
+// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK:           %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK:           return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0 : vector<1x16x16xf32>
+}
+
+// -----
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_mask
+// CHECK:           %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK:           %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK:           %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK:           %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
+// CHECK:           %[[CONTRACT:.*]] = vector.mask %[[M]] {
+// CHECK-SAME:      vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> 
+// CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT:      return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_mask(
+  %arg0: vector<1x16x8xf32>,
+  %arg1: vector<1x8x16xf32>,
+  %arg2: vector<1x16x16xf32>,
+  %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2  : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0: vector<1x16x16xf32>
+}
+
+// -----
+
 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
   return %0: vector<1x1x2x16xf32>
 }
 
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-
-// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
-// CHECK:      %[[MASK:.+]] = vector.constant_mask
-// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
-// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
-// CHECK-SAME:   vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
-// CHECK:      return %[[RET]] : vector<1x16x16xf32>
-
-#contraction_accesses0 = [
-  affine_map<(l, i, j, k) -> (l, i, k)>,
-  affine_map<(l, i, j, k) -> (l, k, j)>,
-  affine_map<(l, i, j, k) -> (l, i, j)>
-]
-#contraction_trait0 = {
-  indexing_maps = #contraction_accesses0,
-  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-}
-
-func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
-  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
-  %0 = vector.mask %mask {
-    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
-  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
-  return %0 : vector<1x16x16xf32>
-}
 
 // -----
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims

dcaballe (Contributor) commented on the new `if (isMasked)` block in VectorDropLeadUnitDim.cpp:

There are some snippets under the comment `Vector mask setup` or similar that do something like this. I think we should perhaps create a utility for those and define some kind of canonical form to do this at the pattern rewrite level. For conversions, the goal when we introduced the mask op was to try to reduce this kind of conditional code and provide some infra to make it as transparent as possible, hence the mask conversion pattern class that we use for the LLVM conversion. Let's brainstorm a bit about the different options and see what can be improved in this regard.

tanmaysachan (Contributor) commented on Mar 2, 2024

@dcaballe I was working on a general-purpose implementation of this, something like tanmaysachan@242f901

Does this look okay to you? I can refactor the code shared between the masked and non-masked rewrite classes into a common function, with the pattern rewrites handling the entry points.

banach-space (Contributor, Author) commented:

Sorry, returning to this after two weeks of being OOO.

@tanmaysachan, from a quick scan, you have adopted VectorMaskOpConversionBase, which matches vector.mask. IIUC, that's going to be insufficient in this case - we need something that works for both:

  • vector.mask {vector.contract}, and
  • vector.contract.

VectorMaskOpConversionBase would only work for the first case. Perhaps @dcaballe had something else in mind, but IMHO we need another "base" class to accommodate that. This is a blocker for me, so I'm proposing one here (*):

I think we should perhaps create a utility for those and define some kind of canonical form to do this at the pattern rewrite level.

Does #83827 make sense? Happy to try something else :) Also, regardless of the long-term approach that we take here, would you be OK with me landing this to unblock linalg.mmt4d investigation?

(*) Apologies @tanmaysachan if you are also actively working on this, but this is quite urgent for me. Mindful of delays due to time difference, I decided to go ahead and draft something quickly.
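
To make the two cases above concrete, here is a minimal sketch of the dispatch such a shared entry point has to perform. It uses only calls that already appear in the diff (`isMasked()`, `setInsertionPointAfter()`, `maskOperation()`); the free-function framing and name are illustrative:

// Sketch only: one entry point covering both a bare vector.contract and a
// vector.mask { vector.contract } wrapper. Mirrors calls from the diff
// above; assumes the same namespace context as VectorDropLeadUnitDim.cpp.
static LogicalResult
rewriteMaybeMaskedContraction(vector::ContractionOp contractOp,
                              RewriterBase &rewriter) {
  auto maskable = cast<MaskableOpInterface>(contractOp.getOperation());
  if (maskable.isMasked()) {
    // Case 1: vector.mask { vector.contract }. New ops must land outside
    // the vector.mask region, so move the insertion point past it.
    rewriter.setInsertionPointAfter(contractOp->getParentOp());
    // ... rewrite here, then re-wrap the new contraction with
    // mlir::vector::maskOperation() and a correspondingly adjusted mask.
  } else {
    // Case 2: a bare vector.contract; rewrite in place.
  }
  return failure(); // placeholder in this sketch
}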

tanmaysachan (Contributor) commented:

No worries, I was unsure what the way forward for this was.

banach-space (Contributor, Author) commented:

Rebased on top of #83827

banach-space force-pushed the andrzej/add_mask_support branch from 8d1a121 to 1e9c7a1 on March 16, 2024 at 18:58
banach-space added a commit to banach-space/llvm-project that referenced this pull request Mar 16, 2024
Stopped while trying to get the added example to compile using masks.
* Depends on llvm#81906, which solves one issue.
* Not sure what happens after that.
* Compare against non-masked version and "reverse engineer" a solution/pipeline.
banach-space force-pushed the andrzej/add_mask_support branch 2 times, most recently from 8d650b3 to f6f3665 on March 21, 2024 at 21:39
banach-space changed the title from "[mlir][Vector] Add support for masks in castAwayContractionLeadingOneDim" to "[mlir][vector] Add support for masks in castAwayContractionLeadingOneDim" on Mar 21, 2024
banach-space (Contributor, Author) commented:

There are some snippets under the comment `Vector mask setup` or similar that do something like this. I think we should perhaps create a utility for those and define some kind of canonical form to do this at the pattern rewrite level.

@dcaballe Done :) I have rewritten this to build on top of #83827, wdyt?

dcaballe (Contributor) left a comment:

Thanks!

banach-space force-pushed the andrzej/add_mask_support branch from f6f3665 to 01d87b5 on March 22, 2024 at 08:16
Updates `castAwayContractionLeadingOneDim` to inherit from
`MaskableOpRewritePattern` so that this pattern can support masking.

Builds on top of llvm#83827
banach-space force-pushed the andrzej/add_mask_support branch from 01d87b5 to 734fbd3 on March 22, 2024 at 08:20
banach-space merged commit 5f1b2cf into llvm:main on Mar 22, 2024
banach-space deleted the andrzej/add_mask_support branch on March 22, 2024 at 09:37
chencha3 pushed a commit to chencha3/llvm-project that referenced this pull request Mar 23, 2024
…Dim (llvm#81906)

Updates `castAwayContractionLeadingOneDim` to inherit from
`MaskableOpRewritePattern` so that this pattern can support masking.

Builds on top of llvm#83827