
Commit 2b99b7c

[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n)
The main goal of this and subsequent PRs is to unify and categorize tests in:

  * vector-transfer-flatten.mlir

This should make it easier to identify the edge cases being tested (and how they differ), to remove duplicates, and to add tests for scalable vectors.

The main contributions of this PR:

1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`, i.e. move it near the other tests for xfer_read, unify variable names to match the other xfer_read tests, and highlight what makes this a positive test to better contrast it with `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim`.
2. Make similar changes for `@transfer_write_flattenable_with_dynamic_dims_and_indices`.

Depends on llvm#95743 and llvm#95744.

**Only review the top commit**
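For readers skimming the diff, the sketch below (not part of this commit; function names and exact operands are made up for illustration) shows the contrast the renamed tests are meant to highlight, assuming the flattening patterns only collapse trailing dims that are static and contiguous:

```mlir
// Positive case: the trailing 8x4 block is static and contiguous, so it can be
// collapsed into a single dimension; the leading dynamic dims are simply not
// candidates for flattening and are left untouched.
func.func @sketch_read_leading_dynamic_dims(
    %mem : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
    %i : index, %j : index) -> vector<8x4xi8> {
  %pad = arith.constant 0 : i8
  %c0 = arith.constant 0 : index
  %v = vector.transfer_read %mem[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]}
    : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
  return %v : vector<8x4xi8>
}

// Negative case: the part of the memref covered by the vector contains a
// dynamic dim, so it cannot be collapsed into a statically-shaped 1-D read and
// the op is left as is (see the *_trailing_dynamic_dim tests below).
func.func @sketch_read_trailing_dynamic_dim(
    %mem : memref<1x?x4x6xi32>, %i : index, %j : index) -> vector<1x2x6xi32> {
  %pad = arith.constant 0 : i32
  %c0 = arith.constant 0 : index
  %v = vector.transfer_read %mem[%c0, %i, %j, %c0], %pad {in_bounds = [true, true, true]}
    : memref<1x?x4x6xi32>, vector<1x2x6xi32>
  return %v : vector<1x2x6xi32>
}
```

In the positive case the patterns are expected to rewrite the read into a memref.collapse_shape, a 1-D vector.transfer_read, and a vector.shape_cast, which is what the CHECK lines in the diff below verify.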
1 parent 867ff2d commit 2b99b7c

File tree

1 file changed
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 70 additions & 54 deletions
@@ -131,10 +131,42 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
+/// The leading dynamic shapes don't affect whether this example is flattenable
+/// or not as those dynamic shapes are not candidates for flattening anyway.
+
+func.func @transfer_read_leading_dynamic_dims(
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) -> vector<8x4xi8> {
+
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK: return %[[VEC2D]] : vector<8x4xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
     %idx_1: index,
     %idx_2: index,
     %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -146,11 +178,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
   return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 // CHECK-128B-NOT: memref.collapse_shape
 
 // -----
@@ -345,10 +377,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
+// The leading dynamic shapes don't affect whether this example is flattenable
+// or not as those dynamic shapes are not candidates for flattening anyway.
+
+func.func @transfer_write_leading_dynamic_dims(
+    %vec : vector<8x4xi8>,
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
+// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -361,11 +423,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
   return
 }
 
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
 // CHECK-128B-NOT: memref.collapse_shape
 
 // -----
@@ -434,56 +496,10 @@ func.func @transfer_write_non_contiguous_src(
 // -----
 
 ///----------------------------------------------------------------------------------------
-/// TODO: Categorize + re-format
+/// [Pattern: DropUnitDimFromElementwiseOps]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
 ///----------------------------------------------------------------------------------------
 
-func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
-  %c0_i8 = arith.constant 0 : i8
-  %c0 = arith.constant 0 : index
-  %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
-  return %result : vector<8x4xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
-// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
-// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
-// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
-// CHECK: return %[[VEC2D]] : vector<8x4xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
-// CHECK-128B: memref.collapse_shape
-
-// -----
-
-func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
-  return
-}
-
-// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
-// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
-// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
-
-// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
-// CHECK-128B: memref.collapse_shape
-
-// -----
-
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
   %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
   return %add : vector<1x8xi32>

0 commit comments