@@ -114,3 +114,234 @@ func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tenso
114
114
// CHECK-LABEL: func.func @pad_pack_different_padding_value
115
115
// CHECK: tensor.pad
116
116
// CHECK: tensor.pack
117
+
118
+ // -----
119
+
120
+ func.func @tensor_pack_linalg_transpose_fold (%arg0: tensor <56 x57 x1 x64 xf32 >) -> tensor <1 x57 x56 x2 x32 xf32 > {
121
+ %0 = tensor.empty () : tensor <56 x2 x1 x57 x32 xf32 >
122
+ %pack = tensor.pack %arg0
123
+ outer_dims_perm = [0 , 3 , 2 , 1 ]
124
+ inner_dims_pos = [3 ]
125
+ inner_tiles = [32 ]
126
+ into %0 : tensor <56 x57 x1 x64 xf32 > -> tensor <56 x2 x1 x57 x32 xf32 >
127
+
128
+ %1 = tensor.empty () : tensor <1 x57 x56 x2 x32 xf32 >
129
+ %transposed = linalg.transpose
130
+ ins (%pack : tensor <56 x2 x1 x57 x32 xf32 >)
131
+ outs (%1 : tensor <1 x57 x56 x2 x32 xf32 >)
132
+ permutation = [2 , 3 , 0 , 1 , 4 ]
133
+ return %transposed : tensor <1 x57 x56 x2 x32 xf32 >
134
+ }
135
+ // CHECK: func @tensor_pack_linalg_transpose_fold(
136
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
137
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
138
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
139
+ // CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
140
+ // CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
141
+ // CHECK-SAME: into %[[INIT]]
142
+ // CHECK: return %[[PACK]]
143
+
144
+ // -----
145
+
146
+ func.func @tensor_pack_linalg_transpose_fold_with_padding (%arg0: tensor <56 x57 x1 x55 xf32 >, %padding: f32 ) -> tensor <1 x57 x56 x2 x32 xf32 > {
147
+ %0 = tensor.empty () : tensor <56 x2 x1 x57 x32 xf32 >
148
+ %pack = tensor.pack %arg0 padding_value (%padding : f32 )
149
+ outer_dims_perm = [0 , 3 , 2 , 1 ]
150
+ inner_dims_pos = [3 ]
151
+ inner_tiles = [32 ]
152
+ into %0 : tensor <56 x57 x1 x55 xf32 > -> tensor <56 x2 x1 x57 x32 xf32 >
153
+
154
+ %1 = tensor.empty () : tensor <1 x57 x56 x2 x32 xf32 >
155
+ %transposed = linalg.transpose
156
+ ins (%pack : tensor <56 x2 x1 x57 x32 xf32 >)
157
+ outs (%1 : tensor <1 x57 x56 x2 x32 xf32 >)
158
+ permutation = [2 , 3 , 0 , 1 , 4 ]
159
+ return %transposed : tensor <1 x57 x56 x2 x32 xf32 >
160
+ }
161
+ // CHECK: func @tensor_pack_linalg_transpose_fold_with_padding(
162
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
163
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
164
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
165
+ // CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
166
+ // CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
167
+ // CHECK-SAME: into %[[INIT]]
168
+ // CHECK: return %[[PACK]]
169
+
170
+ // -----
171
+
172
+ func.func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm (%arg0: tensor <56 x57 x1 x64 xf32 >) -> tensor <1 x2 x56 x57 x32 xf32 > {
173
+ %0 = tensor.empty () : tensor <56 x57 x1 x2 x32 xf32 >
174
+ %pack = tensor.pack %arg0
175
+ inner_dims_pos = [3 ]
176
+ inner_tiles = [32 ]
177
+ into %0 : tensor <56 x57 x1 x64 xf32 > -> tensor <56 x57 x1 x2 x32 xf32 >
178
+
179
+ %1 = tensor.empty () : tensor <1 x2 x56 x57 x32 xf32 >
180
+ %transposed = linalg.transpose
181
+ ins (%pack : tensor <56 x57 x1 x2 x32 xf32 >)
182
+ outs (%1 : tensor <1 x2 x56 x57 x32 xf32 >)
183
+ permutation = [2 , 3 , 0 , 1 , 4 ]
184
+ return %transposed : tensor <1 x2 x56 x57 x32 xf32 >
185
+ }
186
+ // CHECK: func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm(
187
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
188
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
189
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
190
+ // CHECK-SAME: outer_dims_perm = [2, 3, 0, 1]
191
+ // CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
192
+ // CHECK-SAME: into %[[INIT]]
193
+ // CHECK: return %[[PACK]]
194
+
195
+ // -----
196
+
197
+ func.func @tensor_pack_linalg_transpose_fold_tile_dims_transpose (%arg0: tensor <56 x72 x24 x128 xf32 >) -> tensor <12 x56 x4 x9 x32 x8 x2 xf32 > {
198
+ %0 = tensor.empty () : tensor <4 x9 x12 x56 x8 x2 x32 xf32 >
199
+ %pack = tensor.pack %arg0
200
+ outer_dims_perm = [3 , 1 , 2 , 0 ]
201
+ inner_dims_pos = [1 , 2 , 3 ]
202
+ inner_tiles = [8 , 2 , 32 ]
203
+ into %0 : tensor <56 x72 x24 x128 xf32 > -> tensor <4 x9 x12 x56 x8 x2 x32 xf32 >
204
+
205
+ %1 = tensor.empty () : tensor <12 x56 x4 x9 x32 x8 x2 xf32 >
206
+ %transposed = linalg.transpose
207
+ ins (%pack : tensor <4 x9 x12 x56 x8 x2 x32 xf32 >)
208
+ outs (%1 : tensor <12 x56 x4 x9 x32 x8 x2 xf32 >)
209
+ permutation = [2 , 3 , 0 , 1 , 6 , 4 , 5 ]
210
+ return %transposed : tensor <12 x56 x4 x9 x32 x8 x2 xf32 >
211
+ }
212
+ // CHECK: func @tensor_pack_linalg_transpose_fold_tile_dims_transpose(
213
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x72x24x128xf32>)
214
+ // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<12x56x4x9x32x8x2xf32>
215
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
216
+ // CHECK-SAME: outer_dims_perm = [2, 0, 3, 1]
217
+ // CHECK-SAME: inner_dims_pos = [3, 1, 2] inner_tiles = [32, 8, 2]
218
+ // CHECK-SAME: into %[[INIT]]
219
+ // CHECK: return %[[PACK]]
220
+
221
+ // -----
222
+
223
+ func.func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose (%arg0: tensor <56 x72 x24 x128 xf32 >) -> tensor <9 x56 x2 x12 x32 x8 x4 xf32 > {
224
+ %0 = tensor.empty () : tensor <4 x12 x9 x56 x8 x2 x32 xf32 >
225
+ %pack = tensor.pack %arg0
226
+ outer_dims_perm = [3 , 2 , 1 , 0 ]
227
+ inner_dims_pos = [1 , 2 , 3 ]
228
+ inner_tiles = [8 , 2 , 32 ]
229
+ into %0 : tensor <56 x72 x24 x128 xf32 > -> tensor <4 x12 x9 x56 x8 x2 x32 xf32 >
230
+
231
+ %1 = tensor.empty () : tensor <9 x56 x2 x12 x32 x8 x4 xf32 >
232
+ %transposed = linalg.transpose
233
+ ins (%pack : tensor <4 x12 x9 x56 x8 x2 x32 xf32 >)
234
+ outs (%1 : tensor <9 x56 x2 x12 x32 x8 x4 xf32 >)
235
+ permutation = [2 , 3 , 5 , 1 , 6 , 4 , 0 ]
236
+ return %transposed : tensor <9 x56 x2 x12 x32 x8 x4 xf32 >
237
+ }
238
+ // CHECK: func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose(
239
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x72x24x128xf32>)
240
+ // CHECK: tensor.pack
241
+ // CHECK: linalg.transpose
242
+
243
+ // -----
244
+
245
+ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims (%arg0: tensor <56 x?x?x64 xf32 >) -> tensor <?x?x56 x2 x32 xf32 > {
246
+ %0 = tensor.empty () : tensor <56 x2 x1 x57 x32 xf32 >
247
+ %pack = tensor.pack %arg0
248
+ outer_dims_perm = [0 , 3 , 2 , 1 ]
249
+ inner_dims_pos = [3 ]
250
+ inner_tiles = [32 ]
251
+ into %0 : tensor <56 x?x?x64 xf32 > -> tensor <56 x2 x1 x57 x32 xf32 >
252
+
253
+ %1 = tensor.empty () : tensor <1 x57 x56 x2 x32 xf32 >
254
+ %transposed = linalg.transpose
255
+ ins (%pack : tensor <56 x2 x1 x57 x32 xf32 >)
256
+ outs (%1 : tensor <1 x57 x56 x2 x32 xf32 >)
257
+ permutation = [2 , 3 , 0 , 1 , 4 ]
258
+
259
+ %return_value = tensor.cast %transposed : tensor <1 x57 x56 x2 x32 xf32 > to tensor <?x?x56 x2 x32 xf32 >
260
+ return %return_value : tensor <?x?x56 x2 x32 xf32 >
261
+ }
262
+ // CHECK: func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims(
263
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x?x?x64xf32>)
264
+ // CHECK: %[[c1:.+]] = arith.constant 1 : index
265
+ // CHECK: %[[c2:.+]] = arith.constant 2 : index
266
+ // CHECK: %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<56x?x?x64xf32>
267
+ // CHECK: %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<56x?x?x64xf32>
268
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[dim_0]], %[[dim]]) : tensor<?x?x56x2x32xf32>
269
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
270
+ // CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
271
+ // CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
272
+ // CHECK-SAME: into %[[INIT]]
273
+ // CHECK: return %[[PACK]]
274
+
275
+ // -----
276
+
277
+ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_and_tile_dims (%arg0: tensor <56 x?x?x128 xf32 >) -> tensor <?x?x56 x9 x32 x8 x2 xf32 > {
278
+ %0 = tensor.empty () : tensor <56 x9 x12 x4 x8 x2 x32 xf32 >
279
+ %pack = tensor.pack %arg0
280
+ inner_dims_pos = [1 , 2 , 3 ]
281
+ inner_tiles = [8 , 2 , 32 ]
282
+ into %0 : tensor <56 x?x?x128 xf32 > -> tensor <56 x9 x12 x4 x8 x2 x32 xf32 >
283
+
284
+ %1 = tensor.empty () : tensor <12 x4 x56 x9 x32 x8 x2 xf32 >
285
+ %transposed = linalg.transpose
286
+ ins (%pack : tensor <56 x9 x12 x4 x8 x2 x32 xf32 >)
287
+ outs (%1 : tensor <12 x4 x56 x9 x32 x8 x2 xf32 >)
288
+ permutation = [2 , 3 , 0 , 1 , 6 , 4 , 5 ]
289
+
290
+ %return_value = tensor.cast %transposed : tensor <12 x4 x56 x9 x32 x8 x2 xf32 > to tensor <?x?x56 x9 x32 x8 x2 xf32 >
291
+ return %return_value : tensor <?x?x56 x9 x32 x8 x2 xf32 >
292
+ }
293
+ // CHECK: #[[map:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
294
+ // CHECK: #[[map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
295
+ // CHECK: module {
296
+ // CHECK: func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_and_tile_dims(
297
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<56x?x?x128xf32>)
298
+ // CHECK: %[[c1:.+]] = arith.constant 1 : index
299
+ // CHECK: %[[c2:.+]] = arith.constant 2 : index
300
+ // CHECK: %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<56x?x?x128xf32>
301
+ // CHECK: %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<56x?x?x128xf32>
302
+ // CHECK: %[[mapped_dim1:.+]] = affine.apply #[[map:.+]]()[%[[dim]]]
303
+ // CHECK: %[[mapped_dim2:.+]] = affine.apply #[[map1:.+]]()[%[[dim_0]]]
304
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[mapped_dim2]], %[[mapped_dim1]]) : tensor<?x4x56x?x32x8x2xf32>
305
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 3, 0, 1] inner_dims_pos = [3, 1, 2] inner_tiles = [32, 8, 2] into %[[INIT]] : tensor<56x?x?x128xf32> -> tensor<?x4x56x?x32x8x2xf32>
306
+ // CHECK: %[[CAST:.+]] = tensor.cast %[[PACK]] : tensor<?x4x56x?x32x8x2xf32> to tensor<?x?x56x9x32x8x2xf32>
307
+ // CHECK: return %[[CAST]] : tensor<?x?x56x9x32x8x2xf32>
308
+ // CHECK: }
309
+
310
+ // -----
311
+
312
+ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_sizes (%arg0: tensor <?x?x?x?xf32 >, %pack_dest: tensor <?x?x?x?x?x?x?xf32 >, %transpose_dest: tensor <?x?x?x?x?x?x?xf32 >, %tile_p : index , %tile_q : index , %tile_r : index ) -> tensor <?x?x?x?x?x?x?xf32 > {
313
+ %pack = tensor.pack %arg0
314
+ outer_dims_perm = [3 , 0 , 2 , 1 ]
315
+ inner_dims_pos = [1 , 2 , 3 ]
316
+ inner_tiles = [%tile_p , %tile_q , %tile_r ]
317
+ into %pack_dest : tensor <?x?x?x?xf32 > -> tensor <?x?x?x?x?x?x?xf32 >
318
+
319
+ %transposed = linalg.transpose
320
+ ins (%pack : tensor <?x?x?x?x?x?x?xf32 >)
321
+ outs (%transpose_dest : tensor <?x?x?x?x?x?x?xf32 >)
322
+ permutation = [2 , 3 , 0 , 1 , 6 , 4 , 5 ]
323
+
324
+ return %transposed : tensor <?x?x?x?x?x?x?xf32 >
325
+ }
326
+ // CHECK: #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
327
+ // CHECK: module {
328
+ // CHECK: func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_sizes(
329
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>,
330
+ // CHECK-SAME: %[[PACK_DEST:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[TRANSPOSE_DEST:.+]]: tensor<?x?x?x?x?x?x?xf32>,
331
+ // CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
332
+ // CHECK-SAME: %[[ARG3:.+]]: index)
333
+ // CHECK: %[[c0:.+]] = arith.constant 0 : index
334
+ // CHECK: %[[c1:.+]] = arith.constant 1 : index
335
+ // CHECK: %[[c2:.+]] = arith.constant 2 : index
336
+ // CHECK: %[[c3:.+]] = arith.constant 3 : index
337
+ // CHECK: %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c0]] : tensor<?x?x?x?xf32>
338
+ // CHECK: %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<?x?x?x?xf32>
339
+ // CHECK: %[[dim_1:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<?x?x?x?xf32>
340
+ // CHECK: %[[dim_2:.+]] = tensor.dim %[[ARG0]], %[[c3]] : tensor<?x?x?x?xf32>
341
+ // CHECK: %[[mapped_dim0:.+]] = affine.apply #[[map:.+]]()[%[[dim_2]], %[[ARG3]]]
342
+ // CHECK: %[[mapped_dim1:.+]] = affine.apply #[[map:.+]]()[%[[dim_0]], %[[ARG1]]]
343
+ // CHECK: %[[mapped_dim2:.+]] = affine.apply #[[map:.+]]()[%[[dim_1]], %[[ARG2]]]
344
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[mapped_dim2]], %[[mapped_dim1]], %[[mapped_dim0]], %[[dim]], %[[ARG3]], %[[ARG1]], %[[ARG2]]) : tensor<?x?x?x?x?x?x?xf32>
345
+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
346
+ // CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
347
+ // CHECK: }
0 commit comments