@@ -70,28 +70,29 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
70
70
71
71
// -----
72
72
73
- // The shape of the memref and the vector don't match, but the vector,
74
- // ignoring the unit dimensions, is a contiguous subset of the memref,
75
- // so "flattenable"
73
+ // The shape of the memref and the vector don't match, but the vector is a
74
+ // contiguous subset of the memref, so "flattenable". The leading unit dimensions
75
+ // of the vector have no effect on the memref area read even if they
76
+ // span a non-contiguous part of the memref.
76
77
77
78
func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
78
- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
79
+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
79
80
80
81
%c0 = arith.constant 0 : index
81
82
%cst = arith.constant 0 : i8
82
83
%res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
83
- memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
84
+ memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
84
85
return %res : vector <1 x1 x2 x2 xi8 >
85
86
}
86
87
87
88
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
88
- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
89
+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
89
90
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
90
91
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
91
92
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
92
93
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
93
- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>
94
- // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>, vector<4xi8>
94
+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>
95
+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>, vector<4xi8>
95
96
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
96
97
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
97
98
@@ -416,31 +417,40 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
416
417
417
418
// -----
418
419
420
+ // The shape of the memref and the vector don't match, but the vector is a
421
+ // contiguous subset of the memref, so "flattenable". The leading unit dimensions
422
+ // of the vector have no effect on the memref area written even if they
423
+ // span a non-contiguous part of the memref.
424
+
419
425
func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
420
- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
426
+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>,
421
427
%vec : vector <1 x1 x2 x2 xi8 >) {
422
428
423
429
%c0 = arith.constant 0 : index
424
430
vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
425
- vector <1 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
431
+ vector <1 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>
426
432
return
427
433
}
428
434
429
435
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
430
- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>>,
436
+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>>,
431
437
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
432
438
// CHECK: %[[C0:.*]] = arith.constant 0 : index
433
439
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
434
440
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
435
- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>
441
+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>
436
442
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
437
- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
443
+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
444
+ // CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
438
445
439
446
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
440
447
// CHECK-128B: memref.collapse_shape
441
448
442
449
// -----
443
450
451
+ // The shape of the memref and the vector don't match, but the vector is a
452
+ // contiguous subset of the memref, so "flattenable".
453
+
444
454
func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
445
455
%mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
446
456
%vec : vector <2 x2 xi8 >) {
0 commit comments