@@ -99,7 +99,9 @@ func.func @expand_collapse_shape_static(
99
99
%arg4: memref <1 x5 xf32 , strided <[5 , 1 ], offset : ?>>,
100
100
%arg5: memref <f32 >,
101
101
%arg6: memref <3 x4 x5 xf32 , strided <[240 , 60 , 10 ], offset : 0 >>,
102
- %arg7: memref <1 x2049 xi64 , strided <[?, ?], offset : ?>>) {
102
+ %arg7: memref <1 x2049 xi64 , strided <[?, ?], offset : ?>>,
103
+ %arg8: memref <1 x1 x1024 xi8 , strided <[40960 , 4096 , 1 ], offset : 0 >>,
104
+ %arg9: memref <24 x1 x1 x1024 xi8 , strided <[40960 , 40960 , 4096 , 1 ], offset : 0 >>) {
103
105
// Reshapes that collapse and expand back a contiguous buffer.
104
106
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
105
107
// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
@@ -163,6 +165,19 @@ func.func @expand_collapse_shape_static(
163
165
memref <1 x2049 xi64 , strided <[?, ?], offset : ?>> into
164
166
memref <2049 xi64 , strided <[?], offset : ?>>
165
167
168
+ // %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
169
+ // %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
170
+
171
+ // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
172
+ %r8 = memref.collapse_shape %arg8 [[0 , 1 , 2 ]] :
173
+ memref <1 x1 x1024 xi8 , strided <[40960 , 4096 , 1 ], offset : 0 >> into
174
+ memref <1024 xi8 , strided <[1 ], offset : 0 >>
175
+
176
+ // CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
177
+ %r9 = memref.collapse_shape %arg9 [[0 ], [1 , 2 , 3 ]] :
178
+ memref <24 x1 x1 x1024 xi8 , strided <[40960 , 40960 , 4096 , 1 ], offset : 0 >> into
179
+ memref <24 x1024 xi8 , strided <[40960 , 1 ], offset : 0 >>
180
+
166
181
// Reshapes that expand and collapse back a contiguous buffer with some 1's.
167
182
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
168
183
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
0 commit comments