Skip to content

Commit 2731d26

Browse files
authored
[mlir][vector] Support more mask types in foldTransferFullMask() (#96761)
By reusing the existing `getMaskFormat()` helper, `foldTransferFullMask()` can be extended to also fold all-true masks defined by `arith.constant`, in addition to `vector.constant_mask`.
1 parent d4e9ba5 commit 2731d26

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4172,11 +4172,7 @@ static LogicalResult foldTransferFullMask(TransferOp op) {
41724172
if (!mask)
41734173
return failure();
41744174

4175-
auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
4176-
if (!constantMask)
4177-
return failure();
4178-
4179-
if (!constantMask.isAllOnesMask())
4175+
if (getMaskFormat(mask) != MaskFormat::AllTrue)
41804176
return failure();
41814177

41824178
op.getMaskMutable().clear();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -868,14 +868,16 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
868868
// -----
869869

870870
// CHECK-LABEL: fold_vector_transfer_masks
871-
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>) {
871+
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
872872
// CHECK: %[[C0:.+]] = arith.constant 0 : index
873873
%c0 = arith.constant 0 : index
874874
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
875875
%f0 = arith.constant 0.0 : f32
876876

877877
%mask = vector.constant_mask [8, 4] : vector<8x4xi1>
878878

879+
%arith_all_true_mask = arith.constant dense<true> : vector<4x[4]xi1>
880+
879881
// CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
880882
%1 = vector.transfer_read %A[%c0, %c0], %f0, %mask
881883
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>
@@ -884,8 +886,14 @@ func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>)
884886
vector.transfer_write %1, %A[%c0, %c0], %mask
885887
{permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>
886888

889+
// CHECK: vector.transfer_read %{{.*}}, %[[F0]] :
890+
%2 = vector.transfer_read %A[%c0, %c0], %f0, %arith_all_true_mask : memref<?x?xf32>, vector<4x[4]xf32>
891+
892+
// CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] :
893+
vector.transfer_write %2, %A[%c0, %c0], %arith_all_true_mask : vector<4x[4]xf32>, memref<?x?xf32>
894+
887895
// CHECK: return
888-
return %1 : vector<4x8xf32>
896+
return %1, %2 : vector<4x8xf32>, vector<4x[4]xf32>
889897
}
890898

891899
// -----

0 commit comments

Comments
 (0)