Skip to content

Commit 76c0798

Browse files
authored
[mlir][memref]: Allow collapse dummy strided unit dim (#103719)
Dimensions of size 1 should be skipped, because their strides are meaningless and could have any arbitrary value.
1 parent 6528157 commit 76c0798

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,6 +2448,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
24482448
if (strict && (stride.saturated || srcStride.saturated))
24492449
return failure();
24502450

2451+
// Dimensions of size 1 should be skipped, because their strides are
2452+
// meaningless and could have any arbitrary value.
2453+
if (srcShape[idx - 1] == 1)
2454+
continue;
2455+
24512456
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
24522457
return failure();
24532458
}

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ func.func @expand_collapse_shape_static(
9999
%arg4: memref<1x5xf32, strided<[5, 1], offset: ?>>,
100100
%arg5: memref<f32>,
101101
%arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>,
102-
%arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>) {
102+
%arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>,
103+
%arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
104+
%arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
103105
// Reshapes that collapse and expand back a contiguous buffer.
104106
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
105107
// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
@@ -163,6 +165,19 @@ func.func @expand_collapse_shape_static(
163165
memref<1x2049xi64, strided<[?, ?], offset: ?>> into
164166
memref<2049xi64, strided<[?], offset: ?>>
165167

168+
// %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
169+
// %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
170+
171+
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
172+
%r8 = memref.collapse_shape %arg8 [[0, 1, 2]] :
173+
memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>> into
174+
memref<1024xi8, strided<[1], offset: 0>>
175+
176+
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
177+
%r9 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
178+
memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>> into
179+
memref<24x1024xi8, strided<[40960, 1], offset: 0>>
180+
166181
// Reshapes that expand and collapse back a contiguous buffer with some 1's.
167182
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
168183
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>

0 commit comments

Comments
 (0)