[mlir][vector] Support more mask types in foldTransferFullMask() #96761
Using the existing `getMaskFormat()`, `foldTransferFullMask()` can be extended to support `arith.constant` masks in addition to `vector.constant_mask`.
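For readers unfamiliar with the helper, here is a rough standalone sketch of the idea: classify a constant mask as all-true, all-false, or unknown, and only drop the mask in the all-true case. This is a simplified model, not the MLIR implementation. `ConstantMaskModel`, `DenseMaskModel`, `classify`, and `canDropMask` are hypothetical names invented for illustration, and scalable dimensions such as `[4]` are not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Three-way classification, mirroring the MaskFormat values used in the diff.
enum class MaskFormat { AllTrue, AllFalse, Unknown };

// Hypothetical model of `vector.constant_mask [d0, d1, ...] : vector<s0xs1x...xi1>`.
struct ConstantMaskModel {
  std::vector<int64_t> maskDimSizes; // the [d0, d1, ...] operands
  std::vector<int64_t> shape;        // the static vector shape
};

// Hypothetical model of `arith.constant dense<...> : vector<...xi1>`,
// flattened to one bool per element.
struct DenseMaskModel {
  std::vector<bool> values;
};

// A constant_mask is all-true when every mask dim covers the full extent,
// and (conservatively) all-false only when every mask dim is zero.
MaskFormat classify(const ConstantMaskModel &m) {
  if (m.shape.empty() || m.maskDimSizes.size() != m.shape.size())
    return MaskFormat::Unknown;
  bool allTrue = true, allFalse = true;
  for (std::size_t i = 0; i < m.shape.size(); ++i) {
    if (m.maskDimSizes[i] < m.shape[i])
      allTrue = false;
    if (m.maskDimSizes[i] > 0)
      allFalse = false;
  }
  if (allTrue)
    return MaskFormat::AllTrue;
  if (allFalse)
    return MaskFormat::AllFalse;
  return MaskFormat::Unknown;
}

// A dense constant is all-true / all-false when every element agrees.
MaskFormat classify(const DenseMaskModel &m) {
  if (m.values.empty())
    return MaskFormat::Unknown;
  bool allTrue = true, allFalse = true;
  for (bool v : m.values) {
    if (v)
      allFalse = false;
    else
      allTrue = false;
  }
  if (allTrue)
    return MaskFormat::AllTrue;
  if (allFalse)
    return MaskFormat::AllFalse;
  return MaskFormat::Unknown;
}

// The fold only fires when the mask is provably all-true.
template <typename MaskT> bool canDropMask(const MaskT &mask) {
  return classify(mask) == MaskFormat::AllTrue;
}
```

In this model, both a `[8, 4]` mask over an `8x4` shape and a splat of `true` classify as all-true, so either could be dropped; a partial mask or a mixed dense constant stays in place.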
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
Full diff: https://github.com/llvm/llvm-project/pull/96761.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6734c80f2760d..149723f51cc12 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4172,11 +4172,7 @@ static LogicalResult foldTransferFullMask(TransferOp op) {
if (!mask)
return failure();
- auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
- if (!constantMask)
- return failure();
-
- if (!constantMask.isAllOnesMask())
+ if (getMaskFormat(mask) != MaskFormat::AllTrue)
return failure();
op.getMaskMutable().clear();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8181f1a8c5d13..ecd49df3b2141 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -868,7 +868,7 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
// -----
// CHECK-LABEL: fold_vector_transfer_masks
-func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
+func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
@@ -876,6 +876,8 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
%mask = vector.constant_mask [8, 4] : vector<8x4xi1>
+ %mask_splat = arith.constant dense<true> : vector<4x[4]xi1>
+
// CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
%1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
@@ -884,8 +886,14 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
vector.transfer_write %1, %A[%c0, %c0], %mask
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
+ // CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
+ %2 = vector.transfer_read %A[%c0, %c0], %f0, %mask_splat : memref<?x?xf32>, vector<4x[4]xf32>
+
+ // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
+ vector.transfer_write %2, %A[%c0, %c0], %mask_splat : vector<4x[4]xf32>, memref<?x?xf32>
+
// CHECK: return
- return %1 : vector<4x8xf32>
+ return %1, %2 : vector<4x8xf32>, vector<4x[4]xf32>
}
// -----
nice! LGTM just one minor nit, cheers
…m#96761) Using the existing `getMaskFormat()` this can be extended to support `arith.constant` masks.