@@ -70,28 +70,29 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(

 // -----

-// The shape of the memref and the vector don't match, but the vector,
-// ignoring the unit dimensions, is a contiguous subset of the memref,
-// so "flattenable"
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable". The leading unit dimensions
+// of the vector have no effect on the memref area read even if they
+// span a non-contiguous part of the memref.

 func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
-    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {

   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
   %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
-    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+    memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
   return %res : vector<1x1x2x2xi8>
 }

 // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
-// CHECK-SAME:  %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:  %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:       %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:       %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:       %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
-// CHECK-SAME:  : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
-// CHECK:       %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>, vector<4xi8>
+// CHECK-SAME:  : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+// CHECK:       %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
 // CHECK:       %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:       return %[[VAL_5]] : vector<1x1x2x2xi8>

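As a side illustration of the comment above (not part of the test file): the elements touched by the vector<1x1x2x2xi8> transfer form one contiguous 4-byte run, even though the memref as a whole is no longer contiguous once the outermost stride is 48 rather than 24. A minimal Python sketch of that offset arithmetic, where offsets is a made-up helper and not anything from MLIR:

    from itertools import product

    # Made-up helper (not MLIR): linearized offsets of every element in an
    # index space of the given shape, laid out with the given strides.
    def offsets(shape, strides):
        return sorted(sum(i * s for i, s in zip(idx, strides))
                      for idx in product(*(range(d) for d in shape)))

    strides = [48, 6, 2, 1]                # layout from the updated test

    # Elements touched by the vector<1x1x2x2xi8> read, relative to its base index:
    print(offsets([1, 1, 2, 2], strides))  # [0, 1, 2, 3] -> one contiguous run of 4 bytes

    # The memref itself is no longer fully contiguous with the outer stride of 48:
    whole = offsets([5, 4, 3, 2], strides)
    print(len(whole), whole[-1] + 1)       # 120 elements spread over 216 slots -> gaps

That contiguous run is exactly what the rewritten IR expresses: collapse the trailing 3x2 dims into 6 and read a flat vector<4xi8> with in_bounds = [true].
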
@@ -412,31 +413,40 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(

 // -----

+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable". The leading unit dimensions
+// of the vector have no effect on the memref area written even if they
+// span a non-contiguous part of the memref.
+
 func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
-    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
     %vec : vector<1x1x2x2xi8>) {

   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
-    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
   return
 }

 // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
-// CHECK-SAME:  %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:  %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
 // CHECK-SAME:  %[[VEC:.*]]: vector<1x1x2x2xi8>) {
 // CHECK:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK:       %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
-// CHECK-SAME:  : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
+// CHECK-SAME:  : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
 // CHECK:       %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
-// CHECK:       vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
+// CHECK:       vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME:  {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>

 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
 // CHECK-128B:  memref.collapse_shape

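The write variant relies on the same arithmetic. As a rough sanity check of the [[0], [1], [2, 3]] reassociation in the CHECK lines (again only an illustration, with mergeable a hypothetical helper): roughly, adjacent dims of a strided memref can only be merged when the outer stride equals the inner size times the inner stride.

    # Hypothetical helper: can dims i and i+1 of a strided memref be merged
    # into one dim without changing the element layout?
    def mergeable(sizes, strides, i):
        return strides[i] == sizes[i + 1] * strides[i + 1]

    sizes, strides = [5, 4, 3, 2], [48, 6, 2, 1]
    print(mergeable(sizes, strides, 2))  # True: 2 == 2 * 1, so 3x2 collapses into 6
    print(mergeable(sizes, strides, 0))  # False: 48 != 4 * 6, the outermost dim stays separate
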
// -----

+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
 func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
     %vec : vector<2x2xi8>) {