[mlir] Exclude masked ops in VectorDropLeadUnitDim #76468
@llvm/pr-subscribers-mlir-vector

Author: Jerry Wu (pzread)

Changes

Don't insert cast ops for ops in the `vector.mask` region in `VectorDropLeadUnitDim`.

According to the vector masking RFC (https://discourse.llvm.org/t/rfc-vector-masking-representation-in-mlir/64964), the `vector.mask` op doesn't support multiple ops in its region. Therefore, in `VectorDropLeadUnitDim` we can't directly insert cast ops in the region. This change temporarily skips such cases as a workaround.

I'm not quite sure of the complete solution for this issue. But the same problem might also happen to other vector transformations which require inserting ops before/after a maskable vector op. Any feedback will be very helpful : )

Full diff: https://github.com/llvm/llvm-project/pull/76468.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 84294e4552a607..65517295aa72d2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -223,6 +223,9 @@ struct CastAwayTransferReadLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
+ // Not supported masked op yet.
+ if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
+ return failure();
// TODO: support 0-d corner case.
if (read.getTransferRank() == 0)
return failure();
@@ -274,6 +277,9 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
+ // Not supported masked op yet.
+ if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
+ return failure();
// TODO: support 0-d corner case.
if (write.getTransferRank() == 0)
return failure();
@@ -325,6 +331,9 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter) {
+ // Not supported masked op yet.
+ if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
+ return failure();
VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
if (oldAccType == nullptr)
return failure();
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index 71dffca8f14da5..f601be04168144 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -164,6 +164,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
return %0: vector<1x1x2x16xf32>
}
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+
+// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
+// CHECK: %[[MASK:.+]] = vector.constant_mask
+// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
+// CHECK: return %[[RET]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+ affine_map<(l, i, j, k) -> (l, i, k)>,
+ affine_map<(l, i, j, k) -> (l, k, j)>,
+ affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+ indexing_maps = #contraction_accesses0,
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+ %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+ %0 = vector.mask %mask {
+ vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+ } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+ return %0 : vector<1x16x16xf32>
+}
+
// -----
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
@@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
// -----
+// CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
+// CHECK: %[[MASK:.+]] = vector.constant_mask
+// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
+// CHECK: return %[[RET]] : vector<1x4xf16>
+func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
+ %c0 = arith.constant 0 : index
+ %f0 = arith.constant 0. : f16
+ %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
+ %ret = vector.mask %mask {
+ vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
+ } : vector<1x4xi1> -> vector<1x4xf16>
+ return %ret: vector<1x4xf16>
+}
+
+// -----
+
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
// -----
+// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
+// CHECK: %[[MASK:.+]] = vector.constant_mask
+// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
+// CHECK: vector.mask %[[CASTED_MASK]] {
+// CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
+// CHECK: return
+func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) {
+ %c0 = arith.constant 0 : index
+ %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
+ vector.mask %mask {
+ vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16>
+ } : vector<1x4xi1>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
@dcaballe kindly ping : )
Makes sense - this is consistent with the docs, thanks! LGTM
> I'm not quite sure of the complete solution for this issue. But the same problem might also happen to other vector transformations which require inserting ops before/after a maskable vector op. Any feedback will be very helpful : )
In my view, we should try to make sure that we use every opportunity possible to drop the unit dims at higher levels of abstraction - to avoid the problem to begin with. Further down things become tricky.
We can implement a pattern rewrite for MaskOp that inherits from a base class like the
This should automatically deal with the insertion point issue of the mask op.
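A rough, hypothetical sketch of the kind of base class being suggested (the class name is elided in the comment above; `MaskAwareRewritePattern` and `matchAndRewriteMaskableOp` below are made-up names, not MLIR's API, and a real helper would also have to re-mask the rewritten op and handle zero-result ops). It only illustrates how the masked/unmasked dispatch and the replacement of the enclosing `vector.mask` could be centralized so derived patterns never insert ops inside the mask region:

```cpp
// Hypothetical sketch only; names are illustrative, not MLIR's API.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

template <typename SourceOp>
struct MaskAwareRewritePattern : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const final {
    // Detect whether the op is wrapped in a vector.mask.
    auto maskable = dyn_cast<MaskableOpInterface>(op.getOperation());
    vector::MaskOp maskOp;
    if (maskable && maskable.isMasked())
      maskOp = cast<vector::MaskOp>(maskable.getMaskingOp().getOperation());

    // The derived pattern creates its replacement ops *outside* the mask
    // region (and is responsible for re-masking the rewritten op), returning
    // the value that replaces the original result. This sketch assumes a
    // single-result op; transfer_write-style ops would need a void variant.
    FailureOr<Value> newResult =
        matchAndRewriteMaskableOp(op, maskOp, rewriter);
    if (failed(newResult))
      return failure();

    // Replace the whole vector.mask when present, otherwise the op itself.
    rewriter.replaceOp(maskOp ? maskOp.getOperation() : op.getOperation(),
                       *newResult);
    return success();
  }

  // Hook implemented by derived patterns.
  virtual FailureOr<Value>
  matchAndRewriteMaskableOp(SourceOp op, vector::MaskOp maskOp,
                            PatternRewriter &rewriter) const = 0;
};
```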
I filed #78787 and added a TODO for the follow-up work.
Don't insert cast ops for ops in `vector.mask` region in `VectorDropLeadUnitDim`.
@dcaballe I was trying to fix this and had a doubt: are you suggesting a pattern rewrite to convert MaskOp directly to LLVM before this optimization, to prevent insertion of cast ops into the mask op's region? Or did I misunderstand this comment?
No, I meant that you can create a
Don't insert cast ops for ops in the `vector.mask` region in `VectorDropLeadUnitDim`.

According to the vector masking RFC (https://discourse.llvm.org/t/rfc-vector-masking-representation-in-mlir/64964), the `vector.mask` op doesn't support multiple ops in its region. Therefore, in `VectorDropLeadUnitDim` we can't directly insert cast ops in the region. This change temporarily skips such cases as a workaround.

I'm not quite sure of the complete solution for this issue. But the same problem might also happen to other vector transformations which require inserting ops before/after a maskable vector op. Any feedback will be very helpful : )
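For context, a minimal sketch of the two situations (not taken from the patch, and the exact rewritten form may differ): in the unmasked case the pattern drops the leading unit dim and inserts a `vector.broadcast` back to the original type right next to the read, while in the masked case the read is the sole op allowed inside the `vector.mask` region, so there is nowhere to put that broadcast and the pattern now bails out.

```mlir
// Unmasked: roughly what the cast-away rewrite produces. The lower-rank read
// and the broadcast restoring the original type can sit side by side.
func.func @unmasked(%src: memref<1x1x4xf16>) -> vector<1x4xf16> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0. : f16
  // Before the rewrite:
  //   vector.transfer_read %src[...] : memref<1x1x4xf16>, vector<1x4xf16>
  %read = vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true]}
      : memref<1x1x4xf16>, vector<4xf16>
  %cast = vector.broadcast %read : vector<4xf16> to vector<1x4xf16>
  return %cast : vector<1x4xf16>
}

// Masked: the transfer_read must be the only op in the vector.mask region, so
// an extra broadcast cannot be inserted there; the pattern skips this case.
func.func @masked(%src: memref<1x1x4xf16>, %mask: vector<1x4xi1>) -> vector<1x4xf16> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0. : f16
  %ret = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true, true]}
        : memref<1x1x4xf16>, vector<1x4xf16>
  } : vector<1x4xi1> -> vector<1x4xf16>
  return %ret : vector<1x4xf16>
}
```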