[mlir][vector] Fix patterns for dropping leading unit dims from masks #73525

Merged 2 commits on Nov 27, 2023
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -160,6 +160,18 @@ getAsConstantIndexOps(ArrayRef<Value> values);
// Vector Masking Utilities
//===----------------------------------------------------------------------===//

/// Infers the mask type for a transfer op given its vector type and
/// permutation map. The mask in a transfer op operation applies to the
/// tensor/buffer part of it and its type should match the vector shape
/// *before* any permutation or broadcasting. For example,
///
/// vecType = vector<1x2x3xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)>
///
/// Has inferred mask type:
///
/// maskType = vector<2x1xi1>
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);

/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
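A minimal sketch of what the newly exposed inferTransferOpMaskType computes, using the shapes from the masked transfer_read test added below (%src, %pad, %mask and %c0 are placeholder names, not from the patch):

    %0 = vector.transfer_read %src[%c0, %c0, %c0], %pad, %mask
           {in_bounds = [true, true, true],
            permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>}
           : memref<1x4x8xf16>, vector<1x1x4xf16>

Here the vector type is vector<1x1x4xf16> and the permutation map swaps the last two dimensions; inverting the (compressed) map and composing it with the vector shape gives the mask type vector<1x4x1xi1>, i.e. the shape on the memref/tensor side before any permutation or broadcasting.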
8 changes: 2 additions & 6 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
p << " : " << getShapedType() << ", " << getVectorType();
}

/// Infers the mask type for a transfer op given its vector type and
/// permutation map. The mask in a transfer op operation applies to the
/// tensor/buffer part of it and its type should match the vector shape
/// *before* any permutation or broadcasting.
static VectorType inferTransferOpMaskType(VectorType vecType,
AffineMap permMap) {
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
AffineMap permMap) {
auto i1Type = IntegerType::get(permMap.getContext(), 1);
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
assert(invPermMap && "Inversed permutation map couldn't be computed");
32 changes: 23 additions & 9 deletions mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
}
};

static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
                                  VectorType newType, AffineMap newMap,
                                  VectorType oldMaskType) {
  // Infer the type of the new mask from the new map.
  VectorType newMaskType = inferTransferOpMaskType(newType, newMap);

  // If the new mask is broadcastable to the old result type, we can safely
  // use a `vector.extract` to get the new mask. Otherwise the best we can
  // do is shape cast.
  if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
      BroadcastableToResult::Success) {
    int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
    return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
  }
  return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
}

// Turns vector.transfer_read on vector with leading 1 dimensions into
// vector.shape_cast followed by vector.transfer_read on vector without leading
// 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim

Value mask = Value();
if (read.getMask()) {
// The mask shape must always match the shape of the written vector, so we
// can safely use the same extraction indices.
int64_t dropDim = oldType.getRank() - newType.getRank();
mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
splatZero(dropDim));
VectorType maskType = read.getMaskType();
mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
newType, newMap, maskType);
}

auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
write.getLoc(), write.getVector(), splatZero(dropDim));

if (write.getMask()) {
// The mask shape must always match the shape of the written vector, so we
// can safely use the same extraction indices.
auto newMask = rewriter.create<vector::ExtractOp>(
write.getLoc(), write.getMask(), splatZero(dropDim));
VectorType maskType = write.getMaskType();
Value newMask = dropUnitDimsFromMask(
rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
write, newVector, write.getSource(), write.getIndices(),
AffineMapAttr::get(newMap), newMask, inBoundsAttr);
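A rough sketch of the two cases the new dropUnitDimsFromMask helper distinguishes, assuming the leading unit dims have been dropped so that the inferred new mask type is vector<4xi1> (%m0 and %m1 are placeholder names, not from the patch):

    // vector<4xi1> is broadcastable to the old mask type vector<1x1x4xi1>, so
    // the leading unit dims can be peeled off with an extract at zero indices.
    %new_mask0 = vector.extract %m0[0, 0] : vector<4xi1> from vector<1x1x4xi1>

    // With a non-identity permutation map the old mask type can be vector<1x4x1xi1>;
    // vector<4xi1> is not broadcastable to it, so the best we can do is a shape cast.
    %new_mask1 = vector.shape_cast %m1 : vector<1x4x1xi1> to vector<4xi1>

The second case is the one the removed code mishandled: it unconditionally emitted the extract, which assumes the mask shares the vector's leading unit dims and breaks down when the permutation map reorders dimensions.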
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
return %0: vector<1x1xf16>
}

// -----

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
%f0 = arith.constant 0. : f16
// CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
// CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
// CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
%0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
// CHECK: return %[[CAST]]
return %0: vector<1x1x4xf16>
}

// -----

// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
return
}

// -----

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
// CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
// CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
// CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>

vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
return
}

// -----

// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
func.func @cast_away_elementwise_leading_one_dims(
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,