Skip to content

Commit aecd754

Browse files
committed
[mlir][vector] Update tests for collapse 3/n (nfc)
The main goal of this PR (and subsequent PRs), is to add more tests with scalable vectors to: * vector-transfer-collapse-inner-most-dims.mlir There's quite a few cases to consider, hence this is split into multiple PRs. In this PR, the very first test for `vector.transfer_write` is complemented with all the possible combinations: * scalable (rather than fixed) unit trailing dim, * dynamic (rather than static) trailing dim in the source memref. To this end, the following tests: * `@leading_scalable_dimension_transfer_write` `@trailing_scalable_one_dim_transfer_write` are replaced with: * `@drop_two_inner_most_dim_scalable_inner_dim` and `@negative_scalable_unit_dim`, respectively. In addition: * "_for_transfer_write" is removed from function names (to reduce noise). This is a follow-up for: #94490, #94604 NOTE: This PR is limited to tests for `vector.transfer_write`.
1 parent 77db8b0 commit aecd754

File tree

1 file changed

+57
-33
lines changed

1 file changed

+57
-33
lines changed

mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,14 @@ func.func @negative_non_unit_inner_memref_dim(%arg0: memref<4x8xf32>) -> vector<
254254
// 2. vector.transfer_write
255255
//-----------------------------------------------------------------------------
256256

257-
func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
257+
func.func @drop_two_inner_most_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
258258
%c0 = arith.constant 0 : index
259259
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
260260
{in_bounds = [true, true, true, true, true]}
261261
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
262262
return
263263
}
264-
// CHECK: func.func @drop_two_inner_most_dim_for_transfer_write
264+
// CHECK: func.func @drop_two_inner_most_dim
265265
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
266266
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
267267
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -272,16 +272,67 @@ func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1
272272
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
273273
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
274274

275+
// Same as the top example within this split, but with the inner vector
276+
// dim scalable. Note that this example only makes sense when "16 = [16]" (i.e.
277+
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.
278+
279+
func.func @drop_two_inner_most_dim_scalable_inner_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x[16]x1x1xf32>, %arg2: index) {
280+
%c0 = arith.constant 0 : index
281+
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
282+
{in_bounds = [true, true, true, true, true]}
283+
: vector<1x16x[16]x1x1xf32>, memref<1x512x16x1x1xf32>
284+
return
285+
}
286+
// CHECK: func.func @drop_two_inner_most_dim_scalable_inner_dim
287+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
288+
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
289+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
290+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
291+
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
292+
// CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32, strided<[8192, 16, 1]>>
293+
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x[16]x1x1xf32> to vector<1x16x[16]xf32>
294+
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
295+
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
296+
297+
// Same as the top example within this split, but the trailing unit dim was
298+
// replaced with a dyn dim - not supported
299+
300+
func.func @negative_non_unit_trailing_dim(%arg0: memref<1x512x16x1x?xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
301+
%c0 = arith.constant 0 : index
302+
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
303+
{in_bounds = [true, true, true, true, true]}
304+
: vector<1x16x16x1x1xf32>, memref<1x512x16x1x?xf32>
305+
return
306+
}
307+
// CHECK: func.func @negative_non_unit_trailing_dim
308+
// CHECK-NOT: memref.subview
309+
// CHECK-NOT: vector.shape_cast
310+
311+
// Same as the top example within this split, but with a scalable unit dim in
312+
// the output vector - not supported
313+
314+
func.func @negative_scalable_unit_dim(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x[1]xf32>, %arg2: index) {
315+
%c0 = arith.constant 0 : index
316+
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
317+
{in_bounds = [true, true, true, true, true]}
318+
: vector<1x16x16x1x[1]xf32>, memref<1x512x16x1x1xf32>
319+
return
320+
}
321+
322+
// CHECK: func.func @negative_scalable_unit_dim
323+
// CHECK-NOT: memref.subview
324+
// CHECK-NOT: vector.shape_cast
325+
275326
// -----
276327

277-
func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
328+
func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
278329
%c0 = arith.constant 0 : index
279330
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
280331
{in_bounds = [true, true, true, true]}
281332
: vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
282333
return
283334
}
284-
// CHECK: func.func @drop_inner_most_dim_for_transfer_write
335+
// CHECK: func.func @drop_inner_most_dim
285336
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
286337
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
287338
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -294,14 +345,14 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
294345

295346
// -----
296347

297-
func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
348+
func.func @outer_dyn_drop_inner_most_dim(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
298349
%c0 = arith.constant 0 : index
299350
vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
300351
{in_bounds = [true, true, true, true]}
301352
: vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
302353
return
303354
}
304-
// CHECK: func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
355+
// CHECK: func.func @outer_dyn_drop_inner_most_dim
305356
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
306357
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
307358
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
@@ -325,30 +376,3 @@ func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], o
325376
// The inner most unit dims can not be dropped if the strides are not ones.
326377
// CHECK: func.func @non_unit_strides
327378
// CHECK-NOT: memref.subview
328-
329-
// -----
330-
331-
func.func @leading_scalable_dimension_transfer_write(%dest : memref<24x1xf32>, %vec: vector<[4]x1xf32>) {
332-
%c0 = arith.constant 0 : index
333-
vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x1xf32>, memref<24x1xf32>
334-
return
335-
}
336-
// CHECK: func.func @leading_scalable_dimension_transfer_write
337-
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
338-
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
339-
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0] [24, 1] [1, 1] : memref<24x1xf32> to memref<24xf32, strided<[1]>>
340-
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<[4]x1xf32> to vector<[4]xf32>
341-
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]{{.*}} {in_bounds = [true]} : vector<[4]xf32>, memref<24xf32, strided<[1]>>
342-
343-
// -----
344-
345-
// Negative test: [1] (scalable 1) is _not_ a unit dimension.
346-
func.func @trailing_scalable_one_dim_transfer_write(%dest : memref<24x1xf32>, %vec: vector<4x[1]xf32>, %index: index) {
347-
%c0 = arith.constant 0 : index
348-
vector.transfer_write %vec, %dest[%index, %c0] {in_bounds = [true, true]} : vector<4x[1]xf32>, memref<24x1xf32>
349-
return
350-
}
351-
// CHECK: func.func @trailing_scalable_one_dim_transfer_write
352-
// CHECK-NOT: vector.shape_cast
353-
// CHECK: vector.transfer_write {{.*}} : vector<4x[1]xf32>, memref<24x1xf32>
354-
// CHECK-NOT: vector.shape_cast

0 commit comments

Comments
 (0)