Skip to content

Commit dedc7d4

Browse files
author
Jerry Wu
authored
[mlir] Exclude masked ops in VectorDropLeadUnitDim (#76468)
Don't insert cast ops for ops in `vector.mask` region in `VectorDropLeadUnitDim`.
1 parent 975deb3 commit dedc7d4

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ struct CastAwayTransferReadLeadingOneDim
223223

224224
LogicalResult matchAndRewrite(vector::TransferReadOp read,
225225
PatternRewriter &rewriter) const override {
226+
// TODO(#78787): Not supported masked op yet.
227+
if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
228+
return failure();
226229
// TODO: support 0-d corner case.
227230
if (read.getTransferRank() == 0)
228231
return failure();
@@ -274,6 +277,9 @@ struct CastAwayTransferWriteLeadingOneDim
274277

275278
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
276279
PatternRewriter &rewriter) const override {
280+
// TODO(#78787): Not supported masked op yet.
281+
if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
282+
return failure();
277283
// TODO: support 0-d corner case.
278284
if (write.getTransferRank() == 0)
279285
return failure();
@@ -325,6 +331,9 @@ struct CastAwayTransferWriteLeadingOneDim
325331
LogicalResult
326332
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
327333
RewriterBase &rewriter) {
334+
// TODO(#78787): Not supported masked op yet.
335+
if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
336+
return failure();
328337
VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
329338
if (oldAccType == nullptr)
330339
return failure();

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
164164
return %0: vector<1x1x2x16xf32>
165165
}
166166

167+
// -----
168+
169+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
170+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
171+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
172+
173+
// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
174+
// CHECK: %[[MASK:.+]] = vector.constant_mask
175+
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
176+
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
177+
// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
178+
// CHECK: return %[[RET]] : vector<1x16x16xf32>
179+
180+
#contraction_accesses0 = [
181+
affine_map<(l, i, j, k) -> (l, i, k)>,
182+
affine_map<(l, i, j, k) -> (l, k, j)>,
183+
affine_map<(l, i, j, k) -> (l, i, j)>
184+
]
185+
#contraction_trait0 = {
186+
indexing_maps = #contraction_accesses0,
187+
iterator_types = ["parallel", "parallel", "parallel", "reduction"]
188+
}
189+
190+
func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
191+
%mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
192+
%0 = vector.mask %mask {
193+
vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
194+
} : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
195+
return %0 : vector<1x16x16xf32>
196+
}
197+
167198
// -----
168199
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
169200
func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
@@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
253284

254285
// -----
255286

287+
// CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
288+
// CHECK: %[[MASK:.+]] = vector.constant_mask
289+
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
290+
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
291+
// CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
292+
// CHECK: return %[[RET]] : vector<1x4xf16>
293+
func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
294+
%c0 = arith.constant 0 : index
295+
%f0 = arith.constant 0. : f16
296+
%mask = vector.constant_mask [1, 3] : vector<1x4xi1>
297+
%ret = vector.mask %mask {
298+
vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
299+
} : vector<1x4xi1> -> vector<1x4xf16>
300+
return %ret: vector<1x4xf16>
301+
}
302+
303+
// -----
304+
256305
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
257306
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
258307
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
286335

287336
// -----
288337

338+
// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
339+
// CHECK: %[[MASK:.+]] = vector.constant_mask
340+
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
341+
// CHECK: vector.mask %[[CASTED_MASK]] {
342+
// CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
343+
// CHECK: return
344+
func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) {
345+
%c0 = arith.constant 0 : index
346+
%mask = vector.constant_mask [1, 3] : vector<1x4xi1>
347+
vector.mask %mask {
348+
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16>
349+
} : vector<1x4xi1>
350+
return
351+
}
352+
353+
// -----
354+
289355
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
290356
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
291357
func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {

0 commit comments

Comments
 (0)