Skip to content

Commit f6d4872

Browse files
authored
[mlir][vector] Add folders for full constant transfer masks (llvm#71676)
When the mask bounds of a `vector.constant_mask` exactly equal the shape of the vector, any transfer op consuming that mask will be unaffected by it. Drop the mask in such cases.
1 parent ef6d187 commit f6d4872

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,6 +2276,16 @@ def Vector_ConstantMaskOp :
22762276
```
22772277
}];
22782278

2279+
let extraClassDeclaration = [{
    /// Returns the type of this op's result at index 0, cast to a
    /// `VectorType`.
    VectorType getVectorType() {
      auto resultTypes = getOperation()->getResultTypes();
      return cast<VectorType>(resultTypes[0]);
    }

    /// Returns true when the mask is a uniform vector of `1`s.
    bool isFullMask();
  }];
}];
2288+
22792289
let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)";
22802290
let hasVerifier = 1;
22812291
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3937,6 +3937,23 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
39373937
return success();
39383938
}
39393939

3940+
template <typename TransferOp>
3941+
static LogicalResult foldTransferFullMask(TransferOp op) {
3942+
auto mask = op.getMask();
3943+
if (!mask)
3944+
return failure();
3945+
3946+
auto constantMask = mask.template getDefiningOp<vector::ConstantMaskOp>();
3947+
if (!constantMask)
3948+
return failure();
3949+
3950+
if (!constantMask.isFullMask())
3951+
return failure();
3952+
3953+
op.getMaskMutable().clear();
3954+
return success();
3955+
}
3956+
39403957
/// ```
39413958
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
39423959
/// : vector<1x4xf32>, tensor<4x4xf32>
@@ -3969,6 +3986,8 @@ OpFoldResult TransferReadOp::fold(FoldAdaptor) {
39693986
/// transfer_read(memrefcast) -> transfer_read
39703987
if (succeeded(foldTransferInBoundsAttribute(*this)))
39713988
return getResult();
3989+
if (succeeded(foldTransferFullMask(*this)))
3990+
return getResult();
39723991
if (succeeded(memref::foldMemRefCast(*this)))
39733992
return getResult();
39743993
if (succeeded(tensor::foldTensorCast(*this)))
@@ -4334,6 +4353,8 @@ LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
43344353
return success();
43354354
if (succeeded(foldTransferInBoundsAttribute(*this)))
43364355
return success();
4356+
if (succeeded(foldTransferFullMask(*this)))
4357+
return success();
43374358
return memref::foldMemRefCast(*this);
43384359
}
43394360

@@ -5601,6 +5622,22 @@ LogicalResult ConstantMaskOp::verify() {
56015622
return success();
56025623
}
56035624

5625+
bool ConstantMaskOp::isFullMask() {
5626+
auto resultType = getVectorType();
5627+
// Check the corner case of 0-D vectors first.
5628+
if (resultType.getRank() == 0) {
5629+
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
5630+
return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
5631+
}
5632+
for (const auto [resultSize, intAttr] :
5633+
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5634+
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5635+
if (maskDimSize < resultSize)
5636+
return false;
5637+
}
5638+
return true;
5639+
}
5640+
56045641
//===----------------------------------------------------------------------===//
56055642
// CreateMaskOp
56065643
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,29 @@ func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0: vector<3x4xf32>)
854854

855855
// -----
856856

857+
// The constant mask bounds [8, 4] match the mask type vector<8x4xi1>, so the
// CHECK lines below expect both transfer ops to be printed without a mask
// operand.
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%mem: memref<?x?xf32>) -> (vector<4x8xf32>) {
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  %idx0 = arith.constant 0 : index
  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
  %pad = arith.constant 0.0 : f32

  %full_mask = vector.constant_mask [8, 4] : vector<8x4xi1>

  // CHECK: vector.transfer_read %{{.*}}, %[[F0]] {permutation_map
  %read = vector.transfer_read %mem[%idx0, %idx0], %pad, %full_mask
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : memref<?x?xf32>, vector<4x8xf32>

  // CHECK: vector.transfer_write {{.*}}[%[[C0]], %[[C0]]] {permutation_map
  vector.transfer_write %read, %mem[%idx0, %idx0], %full_mask
      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<4x8xf32>, memref<?x?xf32>

  // CHECK: return
  return %read : vector<4x8xf32>
}
877+
878+
// -----
879+
857880
// CHECK-LABEL: fold_vector_transfers
858881
func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
859882
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)