@@ -131,10 +131,42 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
131
131
132
132
// -----
133
133
134
+ /// The leading dynamic shapes don't affect whether this example is flattenable
135
+ /// or not as those dynamic shapes are not candidates for flattening anyway.
136
+
137
+ func.func @transfer_read_leading_dynamic_dims (
138
+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
139
+ %idx_1 : index ,
140
+ %idx_2 : index ) -> vector <8 x4 xi8 > {
141
+
142
+ %c0_i8 = arith.constant 0 : i8
143
+ %c0 = arith.constant 0 : index
144
+ %result = vector.transfer_read %arg [%idx_1 , %idx_2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
145
+ return %result : vector <8 x4 xi8 >
146
+ }
147
+
148
+ // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
149
+ // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
150
+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
151
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
152
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
153
+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
154
+ // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
155
+ // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
156
+ // CHECK-SAME: {in_bounds = [true]}
157
+ // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
158
+ // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
159
+ // CHECK: return %[[VEC2D]] : vector<8x4xi8>
160
+
161
+ // CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
162
+ // CHECK-128B: memref.collapse_shape
163
+
164
+ // -----
165
+
134
166
// The input memref has a dynamic trailing shape and hence is not flattened.
135
167
// TODO: This case could be supported via memref.dim
136
168
137
- func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes (
169
+ func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
138
170
%idx_1: index ,
139
171
%idx_2: index ,
140
172
%m_in: memref <1 x?x4 x6 xi32 >) -> vector <1 x2 x6 xi32 > {
@@ -146,11 +178,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
146
178
return %v : vector <1 x2 x6 xi32 >
147
179
}
148
180
149
- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
181
+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
150
182
// CHECK-NOT: memref.collapse_shape
151
183
// CHECK-NOT: vector.shape_cast
152
184
153
- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
185
+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
154
186
// CHECK-128B-NOT: memref.collapse_shape
155
187
156
188
// -----
@@ -345,10 +377,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
345
377
346
378
// -----
347
379
380
+ // The leading dynamic shapes don't affect whether this example is flattenable
381
+ // or not as those dynamic shapes are not candidates for flattening anyway.
382
+
383
+ func.func @transfer_write_leading_dynamic_dims (
384
+ %vec : vector <8 x4 xi8 >,
385
+ %arg : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>,
386
+ %idx_1 : index ,
387
+ %idx_2 : index ) {
388
+
389
+ %c0 = arith.constant 0 : index
390
+ vector.transfer_write %vec , %arg [%idx_1 , %idx_2 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
391
+ return
392
+ }
393
+
394
+ // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
395
+ // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
396
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
397
+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
398
+ // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
399
+ // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
400
+ // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
401
+ // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
402
+ // CHECK-SAME: {in_bounds = [true]}
403
+ // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
404
+
405
+ // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
406
+ // CHECK-128B: memref.collapse_shape
407
+
408
+ // -----
409
+
348
410
// The input memref has a dynamic trailing shape and hence is not flattened.
349
411
// TODO: This case could be supported via memref.dim
350
412
351
- func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
413
+ func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
352
414
%idx_1: index ,
353
415
%idx_2: index ,
354
416
%vec : vector <1 x2 x6 xi32 >,
@@ -361,11 +423,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
361
423
return
362
424
}
363
425
364
- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
426
+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
365
427
// CHECK-NOT: memref.collapse_shape
366
428
// CHECK-NOT: vector.shape_cast
367
429
368
- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes (
430
+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim (
369
431
// CHECK-128B-NOT: memref.collapse_shape
370
432
371
433
// -----
@@ -434,56 +496,10 @@ func.func @transfer_write_non_contiguous_src(
434
496
// -----
435
497
436
498
///----------------------------------------------------------------------------------------
437
- /// TODO: Categorize + re-format
499
+ /// [Pattern: DropUnitDimFromElementwiseOps]
500
+ /// TODO: Move to a dedicated file - there's no "flattening" in the following tests
438
501
///----------------------------------------------------------------------------------------
439
502
440
- func.func @transfer_read_flattenable_with_dynamic_dims_and_indices (%arg0 : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) -> vector <8 x4 xi8 > {
441
- %c0_i8 = arith.constant 0 : i8
442
- %c0 = arith.constant 0 : index
443
- %result = vector.transfer_read %arg0 [%arg1 , %arg2 , %c0 , %c0 ], %c0_i8 {in_bounds = [true , true ]} : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, vector <8 x4 xi8 >
444
- return %result : vector <8 x4 xi8 >
445
- }
446
-
447
- // CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
448
- // CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
449
- // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
450
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
451
- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
452
- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
453
- // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
454
- // CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
455
- // CHECK-SAME: {in_bounds = [true]}
456
- // CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
457
- // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
458
- // CHECK: return %[[VEC2D]] : vector<8x4xi8>
459
-
460
- // CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
461
- // CHECK-128B: memref.collapse_shape
462
-
463
- // -----
464
-
465
- func.func @transfer_write_flattenable_with_dynamic_dims_and_indices (%vec : vector <8 x4 xi8 >, %dst : memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>, %arg1 : index , %arg2 : index ) {
466
- %c0 = arith.constant 0 : index
467
- vector.transfer_write %vec , %dst [%arg1 , %arg2 , %c0 , %c0 ] {in_bounds = [true , true ]} : vector <8 x4 xi8 >, memref <?x?x8 x4 xi8 , strided <[?, 32 , 4 , 1 ], offset : ?>>
468
- return
469
- }
470
-
471
- // CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
472
- // CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
473
- // CHECK: %[[C0:.+]] = arith.constant 0 : index
474
- // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
475
- // CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
476
- // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
477
- // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
478
- // CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
479
- // CHECK-SAME: {in_bounds = [true]}
480
- // CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
481
-
482
- // CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
483
- // CHECK-128B: memref.collapse_shape
484
-
485
- // -----
486
-
487
503
func.func @fold_unit_dim_add_basic (%arg0 : vector <1 x8 xi32 >) -> vector <1 x8 xi32 > {
488
504
%add = arith.addi %arg0 , %arg0 : vector <1 x8 xi32 >
489
505
return %add : vector <1 x8 xi32 >
0 commit comments