
[mlir] Exclude masked ops in VectorDropLeadUnitDim #76468


Merged · 3 commits into llvm:main on Jan 21, 2024

Conversation

pzread
Member

@pzread pzread commented Dec 27, 2023

Don't insert cast ops for ops in a `vector.mask` region in `VectorDropLeadUnitDim`.

According to the vector masking RFC (https://discourse.llvm.org/t/rfc-vector-masking-representation-in-mlir/64964), the `vector.mask` op doesn't support multiple ops in its region. Therefore, `VectorDropLeadUnitDim` can't directly insert cast ops into the region. This change temporarily skips such cases as a workaround.

I'm not quite sure what the complete solution for this issue is, but the same problem might also affect other vector transformations that need to insert ops before/after a maskable vector op. Any feedback would be very helpful : )
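For illustration, here is a sketch of the constraint (this snippet is mine, not part of the patch; %src, %pad, and %mask are placeholder names), modeled on the masked transfer_read tests added below:

%ret = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
  // Adding a vector.shape_cast here would give the region a second op,
  // which the vector.mask verifier rejects; any cast has to be created
  // outside the region, along with a matching cast of %mask itself.
} : vector<1x4xi1> -> vector<1x4xf16>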

@pzread pzread changed the title [mlir] Exclude masked ops in vector transformation [mlir] Exclude masked ops in VectorDropLeadUnitDim Dec 27, 2023
@pzread pzread requested a review from dcaballe December 27, 2023 23:16
@pzread pzread marked this pull request as ready for review December 27, 2023 23:16
@llvmbot
Member

llvmbot commented Dec 27, 2023

@llvm/pr-subscribers-mlir-vector

Author: Jerry Wu (pzread)

Changes

Don't insert cast ops for ops in a `vector.mask` region in `VectorDropLeadUnitDim`.

According to the vector masking RFC (https://discourse.llvm.org/t/rfc-vector-masking-representation-in-mlir/64964), the `vector.mask` op doesn't support multiple ops in its region. Therefore, `VectorDropLeadUnitDim` can't directly insert cast ops into the region. This change temporarily skips such cases as a workaround.

I'm not quite sure what the complete solution for this issue is, but the same problem might also affect other vector transformations that need to insert ops before/after a maskable vector op. Any feedback would be very helpful : )


Full diff: https://github.com/llvm/llvm-project/pull/76468.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp (+9)
  • (modified) mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir (+66)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 84294e4552a607..65517295aa72d2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -223,6 +223,9 @@ struct CastAwayTransferReadLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                 PatternRewriter &rewriter) const override {
+    // Masked ops are not supported yet.
+    if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
+      return failure();
     // TODO: support 0-d corner case.
     if (read.getTransferRank() == 0)
       return failure();
@@ -274,6 +277,9 @@ struct CastAwayTransferWriteLeadingOneDim
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                 PatternRewriter &rewriter) const override {
+    // Masked ops are not supported yet.
+    if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
+      return failure();
     // TODO: support 0-d corner case.
     if (write.getTransferRank() == 0)
       return failure();
@@ -325,6 +331,9 @@ struct CastAwayTransferWriteLeadingOneDim
 LogicalResult
 mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
                                                RewriterBase &rewriter) {
+  // Masked ops are not supported yet.
+  if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
+    return failure();
   VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
   if (oldAccType == nullptr)
     return failure();
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 71dffca8f14da5..f601be04168144 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -164,6 +164,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
   return %0: vector<1x1x2x16xf32>
 }
 
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
+// CHECK:      %[[MASK:.+]] = vector.constant_mask
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME:   vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
+// CHECK:      return %[[RET]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+  affine_map<(l, i, j, k) -> (l, i, k)>,
+  affine_map<(l, i, j, k) -> (l, k, j)>,
+  affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+  indexing_maps = #contraction_accesses0,
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+  %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+  %0 = vector.mask %mask {
+    vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+  } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+  return %0 : vector<1x16x16xf32>
+}
+
 // -----
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
 func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
@@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
 
 // -----
 
+// CHECK-LABEL: func @not_insert_cast_for_transfer_read_under_mask
+// CHECK:      %[[MASK:.+]] = vector.constant_mask
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME:   vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
+// CHECK:      return %[[RET]] : vector<1x4xf16>
+func.func @not_insert_cast_for_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0. : f16
+  %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
+  %ret = vector.mask %mask {
+    vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
+  } : vector<1x4xi1> -> vector<1x4xf16>
+  return %ret: vector<1x4xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
 func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
 
 // -----
 
+// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
+// CHECK:      %[[MASK:.+]] = vector.constant_mask
+// CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK:      vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME:   vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
+// CHECK:      return
+func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) {
+  %c0 = arith.constant 0 : index
+  %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
+  vector.mask %mask {
+    vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16>
+  } : vector<1x4xi1>
+  return
+}
+
+// -----
+
 // CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
 // CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
 func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {

@llvmbot
Member

llvmbot commented Dec 27, 2023

@llvm/pr-subscribers-mlir

(Same changes summary and full diff as in the @llvm/pr-subscribers-mlir-vector notification above.)

@pzread
Member Author

pzread commented Jan 3, 2024

@dcaballe kindly ping : )

@banach-space banach-space self-requested a review January 4, 2024 18:17
Contributor

@banach-space banach-space left a comment


Makes sense - this is consistent with the docs, thanks! LGTM

I'm not quite sure what the complete solution for this issue is, but the same problem might also affect other vector transformations that need to insert ops before/after a maskable vector op. Any feedback would be very helpful : )

In my view, we should try to make sure that we use every opportunity possible to drop the unit dims at higher levels of abstraction - to avoid the problem to begin with. Further down things become tricky.

@dcaballe
Contributor

We can implement a pattern rewrite for MaskOp that inherits from a base class like the VectorMaskOpConversionBase here:

This should automatically deal with the insertion point issue of the mask op.
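
A rough sketch of that idea (MaskedOpRewritePatternBase and matchAndRewriteMaskedOp below are hypothetical names for illustration, not the upstream API; vector::MaskOp, getMaskableOp(), and OpRewritePattern are existing MLIR constructs):

// Sketch only: a greedy-rewrite analogue of the LLVM conversion's
// VectorMaskOpConversionBase. MaskedOpRewritePatternBase and
// matchAndRewriteMaskedOp are hypothetical names.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Casting.h"

template <typename SourceOp>
struct MaskedOpRewritePatternBase
    : public mlir::OpRewritePattern<mlir::vector::MaskOp> {
  using mlir::OpRewritePattern<mlir::vector::MaskOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::vector::MaskOp maskOp,
                  mlir::PatternRewriter &rewriter) const final {
    // Only fire when the single op in the mask region has the expected kind.
    auto maskedOp = llvm::dyn_cast_or_null<SourceOp>(maskOp.getMaskableOp());
    if (!maskedOp)
      return mlir::failure();
    // Anchoring the rewrite on the vector.mask op keeps newly created ops
    // (e.g. casts) outside the single-op region.
    rewriter.setInsertionPoint(maskOp);
    return matchAndRewriteMaskedOp(maskOp, maskedOp, rewriter);
  }

  // Derived patterns rewrite the masked op (and the mask, if its shape
  // changes) with the insertion point already set outside the region.
  virtual mlir::LogicalResult
  matchAndRewriteMaskedOp(mlir::vector::MaskOp maskOp, SourceOp maskedOp,
                          mlir::PatternRewriter &rewriter) const = 0;
};

A derived pattern for, say, vector::ContractionOp could then rebuild both the contraction and the mask with the leading unit dim dropped, instead of bailing out as this patch does.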

@pzread pzread force-pushed the handle-vec-mask-in-transform branch from 0a3c888 to 5b534f4 on January 19, 2024 21:22
@pzread
Member Author

pzread commented Jan 19, 2024

We can implement a pattern rewrite for MaskOp that inherits from a base class like the VectorMaskOpConversionBase here:

This should automatically deal with the insertion point issue of the mask op.

I filed #78787 and added a TODO for the follow-up work.

@pzread pzread merged commit dedc7d4 into llvm:main Jan 21, 2024
hanhanW pushed a commit to iree-org/llvm-project that referenced this pull request Jan 22, 2024
Don't insert cast ops for ops in `vector.mask` region in
`VectorDropLeadUnitDim`.
hanhanW pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
Don't insert cast ops for ops in `vector.mask` region in
`VectorDropLeadUnitDim`.
@tanmaysachan
Contributor

tanmaysachan commented Jan 27, 2024

We can implement a pattern rewrite for MaskOp that inherits from a base class like the VectorMaskOpConversionBase here:

This should automatically deal with the insertion point issue of the mask op.

@dcaballe I was trying to fix this and had a doubt: are you suggesting a pattern rewrite to convert MaskOp directly to LLVM before this optimization, to prevent insertion of cast ops into the mask op's region? Or did I misunderstand this comment?

@dcaballe
Contributor

dcaballe commented Feb 1, 2024

No, I meant that you can create a VectorMaskOpConversionBase like the one we have for the LLVM conversion, and then write your conversion pattern for the VectorMaskOp inheriting from it, as we also do for LLVM.
