Skip to content

Commit d9a470e

Browse files
[MLIR] Fix incorrect slice contiguity inference in vector::isContiguousSlice
Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>``. In such cases, only the trailing n dimensions of the memref need to be contiguous, not the entire vector rank. This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write`` ops. The pattern used to collapse a number of dimensions equal the vector rank, which may be is incorrect when leading dimensions are unit-sized. This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.
1 parent 1d025f3 commit d9a470e

File tree

6 files changed

+126
-78
lines changed

6 files changed

+126
-78
lines changed

mlir/include/mlir/Dialect/Utils/IndexingUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ class ArrayAttr;
4040
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
4141
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
4242
///
43-
/// `sizes` elements are asserted to be non-negative.
43+
/// `sizes` element `s0` is asserted to be kDynamic or non-negative.
44+
/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
4445
///
4546
/// Return an empty vector if `sizes` is empty.
4647
SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
4949

5050
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
5151
///
52-
/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
53-
/// checked (the other dims are not relevant). Note that for `vectorType` to be
54-
/// a contiguous slice of `memrefType`, the trailing dims of the latter have
55-
/// to be contiguous - this is checked by looking at the corresponding strides.
52+
/// The leading unit dimensions of the vector type are ignored as they
53+
/// are not relevant to the result. Let N be the number of the vector
54+
/// dimensions after ignoring a leading sequence of unit ones.
5655
///
57-
/// There might be some restriction on the leading dim of `VectorType`:
56+
/// For `vectorType` to be a contiguous slice of `memrefType`
57+
/// a) the N trailing dimensions of the latter must be contiguous, and
58+
/// b) the trailing N dimensions of `vectorType` and `memrefType`,
59+
/// except the first of them, must match.
5860
///
59-
/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
60-
/// of `memrefType` then the leading dim of `vectorType` can be
61-
/// arbitrary.
62-
///
63-
/// Ex. 1.1 contiguous slice, perfect match
64-
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
65-
/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
66-
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
67-
///
68-
/// Case 2. If an "internal" dim of `vectorType` does not match the
69-
/// corresponding trailing dim in `memrefType` then the remaining
70-
/// leading dims of `vectorType` have to be 1 (the first non-matching
71-
/// dim can be arbitrary).
61+
/// Examples:
7262
///
73-
/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
74-
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
75-
/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
76-
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
77-
/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
78-
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
79-
/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
80-
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
63+
/// Ex.1 contiguous slice, perfect match
64+
/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
65+
/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
66+
/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
67+
/// Ex.3 non-contiguous slice, 2 != 3
68+
/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
69+
/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
70+
/// 2 != 3 (allowed)
71+
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
72+
/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored,
73+
/// 2 != 3 (allowed)
74+
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
75+
/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
76+
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
77+
/// Ex.7 contiguous slice, memref needs to be contiguous only on the last
78+
/// dimension
79+
/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
80+
/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last
81+
/// two dimensions, and it isn't
82+
/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
8183
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
8284

8385
/// Returns an iterator for all positions in the leading dimensions of `vType`

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,10 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
6969
//===----------------------------------------------------------------------===//
7070

7171
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
72-
assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
73-
"sizes must be nonnegative");
72+
assert(sizes.size() == 0 ||
73+
((sizes[0] == ShapedType::kDynamic || sizes[0] >= 0) &&
74+
llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
75+
"sizes must be nonnegative");
7476
int64_t unit = 1;
7577
return ::computeSuffixProductImpl(sizes, unit);
7678
}

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
630630
if (transferReadOp.getMask())
631631
return failure();
632632

633-
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
633+
// Determinine the first memref dimension to collapse
634+
int64_t firstDimToCollapse =
635+
sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
634636

635637
// 1. Collapse the source memref
636638
Value collapsedSource =
@@ -722,7 +724,9 @@ class FlattenContiguousRowMajorTransferWritePattern
722724
if (transferWriteOp.getMask())
723725
return failure();
724726

725-
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
727+
// Determinine the first memref dimension to collapse
728+
int64_t firstDimToCollapse =
729+
sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
726730

727731
// 1. Collapse the source memref
728732
Value collapsedSource =

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
258258
if (vectorType.isScalable())
259259
return false;
260260

261-
ArrayRef<int64_t> vectorShape = vectorType.getShape();
262-
auto vecRank = vectorType.getRank();
261+
// Ignore a leading contiguous sequence of unit dimensions in the vector.
262+
ArrayRef<int64_t> vectorShape =
263+
vectorType.getShape().drop_while([](auto v) { return v == 1; });
264+
auto vecRank = vectorShape.size();
263265

264266
if (!memrefType.areTrailingDimsContiguous(vecRank))
265267
return false;
266268

267-
// Extract the trailing dims and strides of the input memref
269+
// Extract the trailing dims of the input memref
268270
auto memrefShape = memrefType.getShape().take_back(vecRank);
269271

270-
// Compare the dims of `vectorType` against `memrefType` (in reverse).
271-
// In the most basic case, all dims will match.
272-
auto firstNonMatchingDim =
273-
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
274-
memrefShape.rbegin(), memrefShape.rend());
275-
if (firstNonMatchingDim.first == vectorShape.rend())
276-
return true;
277-
278-
// One non-matching dim is still fine, however the remaining leading dims of
279-
// `vectorType` need to be 1.
280-
SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
281-
vectorShape.rend());
282-
283-
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
272+
// Compare the dims of `vectorType` against `memrefType`.
273+
// All of the dimensions, except the first must match.
274+
return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
284275
}
285276

286277
std::optional<StaticTileOffsetRange>

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
116116
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
117117
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>
118118
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
119-
// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
120-
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
119+
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
120+
// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
121+
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1032xi32>
121122
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
122-
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
123+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
123124

124125
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
125126
// CHECK-128B-NOT: memref.collapse_shape
@@ -170,16 +171,18 @@ func.func @transfer_read_leading_dynamic_dims(
170171
return %res : vector<8x4xi8>
171172
}
172173

174+
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
175+
173176
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
174177
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
175178
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
176-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
177-
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
178-
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
179+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
180+
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
181+
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
179182
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
180-
// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
183+
// CHECK-SAME: [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
181184
// CHECK-SAME: {in_bounds = [true]}
182-
// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
185+
// CHECK-SAME: : memref<?x?xi8, {{.+}}>, vector<32xi8>
183186
// CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
184187
// CHECK: return %[[RES]] : vector<8x4xi8>
185188

@@ -210,13 +213,12 @@ func.func @transfer_read_dynamic_dim_to_flatten(
210213
// CHECK-SAME: %[[IDX_2:arg1]]
211214
// CHECK-SAME: %[[MEM:arg2]]
212215
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
213-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
214216
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
215-
// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
216-
// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
217+
// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
218+
// CHECK-SAME: memref<1x?x4x6xi32> into memref<?xi32>
217219
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
218-
// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
219-
// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
220+
// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
221+
// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
220222
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
221223
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
222224

@@ -397,11 +399,10 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
397399
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
398400
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>,
399401
// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
400-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
401402
// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
402-
// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
403+
// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
403404
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
404-
// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
405+
// CHECK: vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
405406

406407
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
407408
// CHECK-128B-NOT: memref.collapse_shape
@@ -449,16 +450,18 @@ func.func @transfer_write_leading_dynamic_dims(
449450
return
450451
}
451452

453+
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
454+
452455
// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
453456
// CHECK-SAME: %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
454-
// CHECK: %[[C0:.+]] = arith.constant 0 : index
455-
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
456-
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
457+
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
458+
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
459+
// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
457460
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
458461
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
459-
// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
462+
// CHECK-SAME: [%[[ARG2]], %[[COLLAPSED_IDX]]]
460463
// CHECK-SAME: {in_bounds = [true]}
461-
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
464+
// CHECK-SAME: : vector<32xi8>, memref<?x?xi8, {{.+}}>
462465

463466
// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
464467
// CHECK-128B: memref.collapse_shape
@@ -488,14 +491,13 @@ func.func @transfer_write_dynamic_to_flatten(
488491
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
489492
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
490493

491-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
492494
// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
493-
// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
494-
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
495+
// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
496+
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<?xi32>
495497
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
496498
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
497-
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
498-
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
499+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
500+
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
499501

500502
// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
501503
// CHECK-128B-NOT: memref.collapse_shape
@@ -573,8 +575,12 @@ func.func @negative_out_of_bound_transfer_read(
573575
memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
574576
return %res : vector<5x4x3x2xi8>
575577
}
576-
// CHECK: func.func @negative_out_of_bound_transfer_read
577-
// CHECK-NOT: memref.collapse_shape
578+
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
579+
// CHECK-NOT: memref.collapse_shape
580+
581+
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
582+
// CHECK-128B-NOT: memref.collapse_shape
583+
// CHECK-128B-NOT: vector.shape_cast
578584

579585
// -----
580586

@@ -585,5 +591,47 @@ func.func @negative_out_of_bound_transfer_write(
585591
vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
586592
return
587593
}
588-
// CHECK: func.func @negative_out_of_bound_transfer_write
589-
// CHECK-NOT: memref.collapse_shape
594+
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
595+
// CHECK-NOT: memref.collapse_shape
596+
597+
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
598+
// CHECK-128B-NOT: memref.collapse_shape
599+
// CHECK-128B-NOT: vector.shape_cast
600+
601+
// -----
602+
603+
func.func @discontig_mem_contig_slice(
604+
%mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
605+
%c0 = arith.constant 0 : index
606+
vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
607+
vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
608+
return
609+
}
610+
611+
// CHECK-LABEL: func.func @discontig_mem_contig_slice
612+
// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
613+
// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
614+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
615+
// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
616+
// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
617+
// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
618+
619+
// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
620+
// CHECK-128B-NOT: vector.shape_cast
621+
622+
// -----
623+
624+
func.func @discontig_mem_discontig_slice(
625+
%mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
626+
%c0 = arith.constant 0 : index
627+
vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
628+
vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
629+
return
630+
}
631+
632+
// CHECK-LABEL: func.func @discontig_mem_discontig_slice
633+
// CHECK-NOT: vector.shape_cast
634+
635+
// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
636+
// CHECK-128B-NOT: vector.shape_cast
637+

0 commit comments

Comments
 (0)