Skip to content

Commit 6e8f7d5

Browse files
authored
[mlir][vector] Fix patterns for dropping leading unit dims from masks (#73525)
Previously, these patterns only worked when the permutation map was a minor identity. This change infers the new mask type from the new permutation map obtained after dropping the leading unit dims.
1 parent 1bdb2e8 commit 6e8f7d5

File tree

4 files changed

+77
-15
lines changed

4 files changed

+77
-15
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,18 @@ getAsConstantIndexOps(ArrayRef<Value> values);
160160
// Vector Masking Utilities
161161
//===----------------------------------------------------------------------===//
162162

163+
/// Infers the mask type for a transfer op given its vector type and
164+
/// permutation map. The mask in a transfer op applies to the
165+
/// tensor/buffer part of it and its type should match the vector shape
166+
/// *before* any permutation or broadcasting. For example,
167+
///
168+
/// vecType = vector<1x2xf32>, permMap = affine_map<(d0, d1, d2) -> (d1, d0)>
169+
///
170+
/// Has inferred mask type:
171+
///
172+
/// maskType = vector<2x1xi1>
173+
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap);
174+
163175
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
164176
/// as masked operation.
165177
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3754,12 +3754,8 @@ void TransferReadOp::print(OpAsmPrinter &p) {
37543754
p << " : " << getShapedType() << ", " << getVectorType();
37553755
}
37563756

3757-
/// Infers the mask type for a transfer op given its vector type and
3758-
/// permutation map. The mask in a transfer op operation applies to the
3759-
/// tensor/buffer part of it and its type should match the vector shape
3760-
/// *before* any permutation or broadcasting.
3761-
static VectorType inferTransferOpMaskType(VectorType vecType,
3762-
AffineMap permMap) {
3757+
VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
3758+
AffineMap permMap) {
37633759
auto i1Type = IntegerType::get(permMap.getContext(), 1);
37643760
AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
37653761
assert(invPermMap && "Inversed permutation map couldn't be computed");

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,23 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
197197
}
198198
};
199199

200+
/// Returns the mask to use for a transfer op whose vector type has been
/// rewritten to `newType` and whose permutation map has been rewritten to
/// `newMap`. The new mask type is inferred from `newType` and `newMap`;
/// `mask` (of type `oldMaskType`) is then converted to that type with either
/// a `vector.extract` or a `vector.shape_cast` created at `loc` through `b`.
static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
201+
VectorType newType, AffineMap newMap,
202+
VectorType oldMaskType) {
203+
// Infer the type of the new mask from the new map.
204+
VectorType newMaskType = inferTransferOpMaskType(newType, newMap);
205+
206+
// If the new mask type is broadcastable to the old mask type, we can safely
207+
// use a `vector.extract` over the rank difference to get the new mask
208+
// (`splatZero` is presumably all-zero indices). Otherwise we shape cast.
209+
if (vector::isBroadcastableTo(newMaskType, oldMaskType) ==
210+
BroadcastableToResult::Success) {
211+
int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank();
212+
return b.create<vector::ExtractOp>(loc, mask, splatZero(dropDim));
213+
}
214+
return b.create<vector::ShapeCastOp>(loc, newMaskType, mask);
215+
}
216+
200217
// Turns vector.transfer_read on vector with leading 1 dimensions into
201218
// vector.shape_cast followed by vector.transfer_read on vector without leading
202219
// 1 dimensions.
@@ -234,11 +251,9 @@ struct CastAwayTransferReadLeadingOneDim
234251

235252
Value mask = Value();
236253
if (read.getMask()) {
237-
// The mask shape must always match the shape of the written vector, so we
238-
// can safely use the same extraction indices.
239-
int64_t dropDim = oldType.getRank() - newType.getRank();
240-
mask = rewriter.create<vector::ExtractOp>(read.getLoc(), read.getMask(),
241-
splatZero(dropDim));
254+
VectorType maskType = read.getMaskType();
255+
mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(),
256+
newType, newMap, maskType);
242257
}
243258

244259
auto newRead = rewriter.create<vector::TransferReadOp>(
@@ -289,10 +304,9 @@ struct CastAwayTransferWriteLeadingOneDim
289304
write.getLoc(), write.getVector(), splatZero(dropDim));
290305

291306
if (write.getMask()) {
292-
// The mask shape must always match the shape of the written vector, so we
293-
// can safely use the same extraction indices.
294-
auto newMask = rewriter.create<vector::ExtractOp>(
295-
write.getLoc(), write.getMask(), splatZero(dropDim));
307+
VectorType maskType = write.getMaskType();
308+
Value newMask = dropUnitDimsFromMask(
309+
rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
296310
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
297311
write, newVector, write.getSource(), write.getIndices(),
298312
AffineMapAttr::get(newMap), newMask, inBoundsAttr);

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,27 @@ func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x
232232
return %0: vector<1x1xf16>
233233
}
234234

235+
// -----
236+
237+
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
238+
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
239+
func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
240+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
241+
%c0 = arith.constant 0 : index
242+
// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
243+
%f0 = arith.constant 0. : f16
244+
// CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
245+
// CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
246+
// CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
247+
// CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
248+
%0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
249+
permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
250+
// CHECK: return %[[CAST]]
251+
return %0: vector<1x1x4xf16>
252+
}
253+
254+
// -----
255+
235256
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
236257
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
237258
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -263,6 +284,25 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
263284
return
264285
}
265286

287+
// -----
288+
289+
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
290+
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
291+
func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
292+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
293+
%c0 = arith.constant 0 : index
294+
// CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
295+
// CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
296+
// CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
297+
// CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
298+
299+
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
300+
permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
301+
return
302+
}
303+
304+
// -----
305+
266306
// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
267307
func.func @cast_away_elementwise_leading_one_dims(
268308
%arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,

0 commit comments

Comments
 (0)