@@ -163,6 +163,41 @@ module attributes {transform.with_named_sequence} {
163
163
164
164
// -----
165
165
166
+ // CHECK-LABEL: func.func @unpack_with_identity_outer_dims_perm(
167
+ func.func @unpack_with_identity_outer_dims_perm (%arg0: tensor <17 x2 x16 x16 x32 x8 xf32 >, %arg1: tensor <129 x47 x16 x16 xf32 >) -> tensor <129 x47 x16 x16 xf32 > {
168
+ %cst_0 = arith.constant 0.0 : f32
169
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
170
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
171
+ // CHECK: %[[TRAN:.*]] = linalg.transpose
172
+ // CHECK-SAME: ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
173
+ // CHECK-SAME: outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
174
+ // CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3]
175
+ // CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
176
+ // CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
177
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
178
+ // CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
179
+ // CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
180
+ // CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
181
+ %unpack = tensor.unpack %arg0 outer_dims_perm = [0 , 1 , 2 , 3 ] inner_dims_pos = [1 , 0 ] inner_tiles = [32 , 8 ] into %arg1
182
+ : tensor <17 x2 x16 x16 x32 x8 xf32 > -> tensor <129 x47 x16 x16 xf32 >
183
+ return %unpack : tensor <129 x47 x16 x16 xf32 >
184
+ }
185
+
186
+ module attributes {transform.with_named_sequence } {
187
+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
188
+ %unpack = transform.structured.match ops {[" tensor.unpack" ]} in %module_op
189
+ : (!transform.any_op ) -> !transform.op <" tensor.unpack" >
190
+ transform.structured.lower_unpack %unpack : (!transform.op <" tensor.unpack" >)
191
+ -> (!transform.op <" tensor.empty" >,
192
+ !transform.op <" linalg.transpose" >,
193
+ !transform.op <" tensor.collapse_shape" >,
194
+ !transform.op <" tensor.extract_slice" >)
195
+ transform.yield
196
+ }
197
+ }
198
+
199
+ // -----
200
+
166
201
// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
167
202
// CHECK-LABEL: func.func @unpack_as_pad(
168
203
func.func @unpack_as_pad (%arg0: tensor <1 x1 x1 x1 x136 x64 x16 x16 xf32 >, %arg1: tensor <129 x47 x16 x16 xf32 >) -> tensor <129 x47 x16 x16 xf32 > {
0 commit comments