@@ -275,3 +275,54 @@ transform.sequence failures(propagate) {
      matmul_inner_dims_order = [1, 2, 0]
    : (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
+
+
+// -----
+
+!A_mk = tensor<1023x255xf32>
+!B_nk = tensor<127x255xf32>
+!C_nm = tensor<127x1023xf32>
+
+#mkn_accesses = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (n, m)>
+]
+#mkn_trait = {
+  indexing_maps = #mkn_accesses,
+  iterator_types = ["parallel", "parallel", "reduction"]
+}
+
+// Normalized dims are: ( k,  m,  n)(kk, nn)
+// CHECK-DAG: #[[$km_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d3)>
+// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d0, d3, d4)>
+// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1, d4)>
+
+// CHECK-LABEL: @matmul_mk_nk_nm(
+func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
+  // CHECK: linalg.generic
+  // CHECK-SAME: indexing_maps = [#[[$km_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
+  // CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel"]}
+  // CHECK-SAME: ins(%{{.*}} : tensor<1023x8x32xf32>, tensor<1x8x32x130xf32>)
+  // CHECK-SAME: outs(%{{.*}} : tensor<1x1023x130xf32>)
+  %0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
+    ^bb0(%a: f32, %b: f32, %c: f32):
+      %d = arith.mulf %a, %b : f32
+      %e = arith.addf %c, %d : f32
+      linalg.yield %e : f32
+  } -> !C_nm
+  return %0 : !C_nm
+}
316
+
317
+ transform.sequence failures (propagate ) {
318
+ ^bb1 (%module_op: !pdl.operation ):
319
+ %generic = transform.structured.match ops {[" linalg.generic" ]} in %module_op : (!pdl.operation ) -> !transform.op <" linalg.generic" >
320
+ transform.structured.pack_greedily %generic
321
+ // In this spec, the "n" dimension is neither packed not unpacked.
322
+ // We don't end up with an innermost matmul after packing but only with an
323
+ // innermost matvec.
324
+ matmul_packed_sizes = [0 , 0 , 32 ]
325
+ matmul_padded_sizes_next_multiple_of = [0 , 10 , 0 ]
326
+ matmul_inner_dims_order = [1 , 2 , 0 ]
327
+ : (!transform.op <" linalg.generic" >) -> !transform.op <" linalg.generic" >
328
+ }