Skip to content

Commit 0a3c888

Browse files
author
Jerry Wu
committed
Add tests
1 parent 4bba066 commit 0a3c888

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,37 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
164164
return %0: vector<1x1x2x16xf32>
165165
}
166166

167+
// -----
168+
169+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
170+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
171+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
172+
173+
// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
174+
// CHECK: %[[MASK:.+]] = vector.constant_mask
175+
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
176+
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
177+
// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
178+
// CHECK: return %[[RET]] : vector<1x16x16xf32>
179+
180+
#contraction_accesses0 = [
181+
affine_map<(l, i, j, k) -> (l, i, k)>,
182+
affine_map<(l, i, j, k) -> (l, k, j)>,
183+
affine_map<(l, i, j, k) -> (l, i, j)>
184+
]
185+
#contraction_trait0 = {
186+
indexing_maps = #contraction_accesses0,
187+
iterator_types = ["parallel", "parallel", "parallel", "reduction"]
188+
}
189+
190+
func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
191+
%mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
192+
%0 = vector.mask %mask {
193+
vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
194+
} : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
195+
return %0 : vector<1x16x16xf32>
196+
}
197+
167198
// -----
168199
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
169200
func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
@@ -253,6 +284,24 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16
253284

254285
// -----
255286

287+
// CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
288+
// CHECK: %[[MASK:.+]] = vector.constant_mask
289+
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
290+
// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
291+
// CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
292+
// CHECK: return %[[RET]] : vector<1x4xf16>
293+
func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
294+
%c0 = arith.constant 0 : index
295+
%f0 = arith.constant 0. : f16
296+
%mask = vector.constant_mask [1, 3] : vector<1x4xi1>
297+
%ret = vector.mask %mask {
298+
vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
299+
} : vector<1x4xi1> -> vector<1x4xf16>
300+
return %ret: vector<1x4xf16>
301+
}
302+
303+
// -----
304+
256305
// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
257306
func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
258307
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -286,6 +335,23 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1
286335

287336
// -----
288337

338+
// CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
339+
// CHECK: %[[MASK:.+]] = vector.constant_mask
340+
// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
341+
// CHECK: vector.mask %[[CASTED_MASK]] {
342+
// CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
343+
// CHECK: return
344+
func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) {
345+
%c0 = arith.constant 0 : index
346+
%mask = vector.constant_mask [1, 3] : vector<1x4xi1>
347+
vector.mask %mask {
348+
vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16>
349+
} : vector<1x4xi1>
350+
return
351+
}
352+
353+
// -----
354+
289355
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
290356
// CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
291357
func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {

0 commit comments

Comments
 (0)