Skip to content

Commit 7145bd2

Browse files
[fixup] Do minimal collapsing of the memref
1 parent 3b17c94 commit 7145bd2

File tree

2 files changed

+69
-75
lines changed

2 files changed

+69
-75
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -581,17 +581,6 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
581581
}
582582

583583
namespace {
584-
585-
/// Helper function to return the index of the last dynamic dimension
586-
/// in `shape` or -1 if there are no dynamic dimensions.
587-
int64_t lastDynIndex(ArrayRef<int64_t> shape) {
588-
return static_cast<int64_t>(
589-
std::distance(
590-
std::find(shape.rbegin(), shape.rend(), ShapedType::kDynamic),
591-
shape.rend()) -
592-
1);
593-
}
594-
595584
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
596585
/// memref.collapse_shape on the source so that the resulting
597586
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -640,10 +629,11 @@ class FlattenContiguousRowMajorTransferReadPattern
640629
if (transferReadOp.getMask())
641630
return failure();
642631

643-
// Determinine the first memref dimension to collapse
644-
int64_t firstDimToCollapse = std::max(
645-
lastDynIndex(sourceType.getShape()),
646-
sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
632+
// Determine the first memref dimension to collapse - just enough so we can
633+
// read a flattened vector.
634+
int64_t firstDimToCollapse =
635+
sourceType.getRank() -
636+
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
647637

648638
// 1. Collapse the source memref
649639
Value collapsedSource =
@@ -735,10 +725,11 @@ class FlattenContiguousRowMajorTransferWritePattern
735725
if (transferWriteOp.getMask())
736726
return failure();
737727

738-
// Determinine the first memref dimension to collapse
739-
int64_t firstDimToCollapse = std::max(
740-
lastDynIndex(sourceType.getShape()),
741-
sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
728+
// Determine the first memref dimension to collapse - just enough so we can
729+
// read a flattened vector.
730+
int64_t firstDimToCollapse =
731+
sourceType.getRank() -
732+
vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
742733

743734
// 1. Collapse the source memref
744735
Value collapsedSource =

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 59 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,10 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
8888
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
8989
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
9090
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
91-
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
92-
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
91+
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
92+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
93+
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
94+
// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>, vector<4xi8>
9395
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
9496
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
9597

@@ -116,10 +118,10 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
116118
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
117119
// CHECK: %[[C0:.+]] = arith.constant 0 : index
118120
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
119-
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
120-
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
121-
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
122-
// CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
121+
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
122+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
123+
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]], %[[C0]]], %[[C0_I8]] {in_bounds = [true]}
124+
// CHECK-SAME: : memref<5x24xi8, strided<[24, 1], offset: ?>>, vector<12xi8>
123125
// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
124126
// CHECK: return %[[VEC]] : vector<2x3x2xi8>
125127

@@ -141,17 +143,18 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
141143
return %res : vector<1x2x6xi32>
142144
}
143145

144-
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
146+
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 * 6)>
145147

146148
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
147149
// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
148150
// CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32>
149-
// CHECK: %[[C_0:.+]] = arith.constant 0 : i32
151+
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
152+
// CHECK: %[[C_0:.+]] = arith.constant 0 : index
150153
// CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
151-
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
152-
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1032xi32>
153-
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
154-
// CHECK: %[[READ:.+]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
154+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
155+
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1x43x24xi32>
156+
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]]]
157+
// CHECK: %[[READ:.+]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0]], %[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I32]] {in_bounds = [true]} : memref<1x43x24xi32>, vector<12xi32>
155158

156159
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
157160
// CHECK-128B-NOT: memref.collapse_shape
@@ -202,18 +205,16 @@ func.func @transfer_read_leading_dynamic_dims(
202205
return %res : vector<8x4xi8>
203206
}
204207

205-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
206-
207208
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
208209
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
209210
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
210-
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
211-
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
212-
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
211+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
212+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
213+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
214+
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
213215
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
214-
// CHECK-SAME: [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
215-
// CHECK-SAME: {in_bounds = [true]}
216-
// CHECK-SAME: : memref<?x?xi8, {{.+}}>, vector<32xi8>
216+
// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
217+
// CHECK-SAME: {in_bounds = [true]} : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
217218
// CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
218219
// CHECK: return %[[RES]] : vector<8x4xi8>
219220

@@ -259,7 +260,7 @@ func.func @transfer_read_dynamic_dim_to_flatten(
259260
return %res : vector<1x2x6xi32>
260261
}
261262

262-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
263+
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 6)>
263264

264265
// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
265266
// CHECK-SAME: %[[IDX_1:arg0]]
@@ -268,11 +269,11 @@ func.func @transfer_read_dynamic_dim_to_flatten(
268269
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
269270
// CHECK: %[[C0:.+]] = arith.constant 0 : index
270271
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
271-
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
272-
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
273-
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
274-
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
275-
// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
272+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
273+
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?x24xi32>
274+
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
275+
// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[IDX_1]], %[[COLLAPSED_IDX]]],
276+
// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?x24xi32>, vector<12xi32>
276277
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
277278
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
278279

@@ -424,10 +425,12 @@ func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
424425
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
425426
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
426427
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
427-
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
428-
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
429-
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
430-
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
428+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
429+
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
430+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
431+
// CHECK-SAME: : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
432+
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
433+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
431434

432435
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
433436
// CHECK-128B: memref.collapse_shape
@@ -447,13 +450,13 @@ func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
447450
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
448451
// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
449452
// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
450-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
451-
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
452-
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
453-
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
454-
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
455-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
456-
// CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
453+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
454+
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
455+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
456+
// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
457+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
458+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
459+
// CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
457460

458461
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
459462
// CHECK-128B: memref.collapse_shape
@@ -473,16 +476,18 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
473476
return
474477
}
475478

476-
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
479+
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 * 6)>
477480

478481
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices(
479482
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
480483
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>,
481484
// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
482-
// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
483-
// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
485+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
486+
// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]]]
487+
// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]]
488+
// CHECK-DAG-SAME{LITERAL}: [[0], [1], [2, 3]] : memref<1x43x4x6xi32> into memref<1x43x24xi32>
484489
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
485-
// CHECK: vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
490+
// CHECK: vector.transfer_write %[[SC]], %[[CS]][%[[C0]], %[[IDX_1]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x43x24xi32>
486491

487492
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
488493
// CHECK-128B-NOT: memref.collapse_shape
@@ -530,24 +535,22 @@ func.func @transfer_write_leading_dynamic_dims(
530535
return
531536
}
532537

533-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
534-
535538
// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
536539
// CHECK-SAME: %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
537-
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
538-
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
539-
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
540+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
541+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
542+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
543+
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
540544
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
541545
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
542-
// CHECK-SAME: [%[[ARG2]], %[[COLLAPSED_IDX]]]
543-
// CHECK-SAME: {in_bounds = [true]}
544-
// CHECK-SAME: : vector<32xi8>, memref<?x?xi8, {{.+}}>
546+
// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]] {in_bounds = [true]}
547+
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
545548

546549
// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
547550
// CHECK-128B: memref.collapse_shape
548551

549552
// -----
550-
553+
551554
// The vector could be a non-contiguous slice of the input
552555
// memref.
553556

@@ -583,7 +586,7 @@ func.func @transfer_write_dynamic_to_flatten(
583586
return
584587
}
585588

586-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
589+
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 6)>
587590

588591
// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
589592
// CHECK-SAME: %[[IDX_1:arg0]]: index
@@ -592,12 +595,12 @@ func.func @transfer_write_dynamic_to_flatten(
592595
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
593596
// CHECK: %[[C0:.+]] = arith.constant 0 : index
594597
// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
595-
// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
596-
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
597-
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
598+
// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
599+
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?x24xi32>
600+
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
598601
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
599-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
600-
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
602+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[IDX_1]], %[[COLLAPSED_IDX]]]
603+
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?x24xi32>
601604

602605
// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
603606
// CHECK-128B-NOT: memref.collapse_shape

0 commit comments

Comments
 (0)