Skip to content

Commit 5f8c31b

Browse files
committed
[mlir][nfc] Add tests for linalg.mmt4d
linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244), but there are virtually no tests in-tree. In the spirit of documenting through tests, this PR adds a few basic examples.
1 parent 4502dc4 commit 5f8c31b

File tree

4 files changed

+135
-0
lines changed

4 files changed

+135
-0
lines changed

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
744744
-> tensor<2x16xf32>
745745
return %1 : tensor<2x16xf32>
746746
}
747+
748+
// -----
749+
750+
// Negative test: the inner tile sizes of the output must match the sizes
// inferred from the LHS/RHS operands. Per the diagnostic below, dimension #3
// of the output is expected to be 8, but the test deliberately provides 1.
func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
751+
%B: tensor<16x16x8x1xf32>,
752+
%C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
753+
// expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
754+
%res = linalg.mmt4d
755+
ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
756+
outs(%C_in: tensor<16x16x8x1xf32>)
757+
-> tensor<16x16x8x1xf32>
758+
return %res : tensor<16x16x8x1xf32>
759+
}
760+
761+
// -----
762+
763+
// Negative test: the output operand must be rank-4 to match the result rank
// of the output indexing map. Per the diagnostic below, the rank-2 output
// (tensor<8x8xf32>) is rejected.
func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
764+
%B: tensor<16x16x8x1xf32>,
765+
%C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
766+
// expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
767+
%res = linalg.mmt4d
768+
ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
769+
outs(%C_in: tensor<8x8xf32>)
770+
-> tensor<8x8xf32>
771+
return %res : tensor<8x8xf32>
772+
}

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
12191219

12201220
// -----
12211221

1222+
// CHECK-LABEL: func @mmt4d
1223+
// Round-trip test for the plain (unbatched) linalg.mmt4d named op.
// Fix: %C and the function's result type must match the 4-D accumulator type
// the op actually uses below (tensor<10x80x8x4xf32>); the original 5-D
// tensor<128x10x80x8x4xf32> was copied from the batch_mmt4d test and would
// fail verification.
func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
1224+
// CHECK: %{{.+}} = linalg.mmt4d
1225+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
1226+
// CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
1227+
%0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
1228+
return %0: tensor<10x80x8x4xf32>
1229+
}
1230+
1231+
// -----
1232+
12221233
// CHECK-LABEL: func @batch_mmt4d
12231234
func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
12241235
// CHECK: %{{.+}} = linalg.batch_mmt4d
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
2+
3+
// Payload: a single linalg.mmt4d op that the transform sequence below tiles,
// vectorizes, and lowers all the way to vector.fma.
func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
4+
%res = linalg.mmt4d
5+
ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
6+
outs(%C_in: tensor<16x16x8x8xf32>)
7+
-> tensor<16x16x8x8xf32>
8+
return %res : tensor<16x16x8x8xf32>
9+
}
10+
11+
// Transform script driving @mmt4d_to_fma: tile (parallel, then reduction
// dims), vectorize, simplify the resulting vector ops, and finally lower
// vector.contract to vector.fma via outer products.
module attributes {transform.with_named_sequence} {
12+
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
13+
%func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
14+
15+
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
16+
17+
// Step 1: Tile
18+
: (!transform.op<"func.func">) -> !transform.any_op
19+
// Tile parallel dims
20+
%tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
21+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
22+
// Tile reduction dims
23+
%tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
24+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
25+
26+
// Step 2: Vectorize
27+
transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
28+
29+
// Step 3: Simplify
30+
// vector.multi_reduction --> vector.contract
31+
// Generates a 6-dim vector.contract with the dims matching the original MMT4D Op
32+
// and with the following split into parallel and reduction dims:
33+
// * parallel, parallel, reduction, parallel, parallel, reduction
34+
transform.apply_patterns to %func {
35+
transform.apply_patterns.vector.reduction_to_contract
36+
// Reduce the rank of xfer ops. This transforms vector.contract to be
37+
// more matmul-like and to enable the lowering to outer product Ops.
38+
transform.apply_patterns.vector.transfer_permutation_patterns
39+
} : !transform.op<"func.func">
40+
41+
// Hoisting and LICM - not strictly required
42+
%func_h = transform.structured.hoist_redundant_vector_transfers %func
43+
: (!transform.op<"func.func">) -> !transform.op<"func.func">
44+
%all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
45+
: (!transform.op<"func.func">) -> !transform.any_op
46+
transform.apply_licm to %all_loops : !transform.any_op
47+
transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
48+
49+
// Simplify the 6-dim vector.contract into a 3-dim matmul-like
50+
// vector.contract with the following split into parallel and reduction
51+
// dims:
52+
// * parallel, parallel, reduction
53+
transform.apply_patterns to %func_h {
54+
transform.apply_patterns.vector.reduction_to_contract
55+
transform.apply_patterns.vector.cast_away_vector_leading_one_dim
56+
transform.apply_patterns.canonicalization
57+
} : !transform.op<"func.func">
58+
59+
// Step 4: Lower vector.contract to vector.fma
60+
transform.apply_patterns to %func_h {
61+
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
62+
transform.apply_patterns.vector.lower_outerproduct
63+
} : !transform.op<"func.func">
64+
transform.yield
65+
}
66+
}

mlir/test/Dialect/Linalg/vectorization.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,38 @@ module attributes {transform.with_named_sequence} {
639639

640640
// -----
641641

642+
// Vectorization test payload: linalg.mmt4d on memrefs, vectorized in one shot
// (no tiling) by the transform script below; expected IR is pinned by the
// CHECK lines that follow.
func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
643+
linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
644+
outs(%C_in: memref<16x16x8x8xf32>)
645+
return
646+
}
647+
648+
// CHECK-LABEL: func.func @mmt4d(
649+
// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
650+
// CHECK: %[[VAL_3:.*]] = arith.constant 16 : index
651+
// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
652+
// CHECK: %[[VAL_5:.*]] = arith.constant 16 : index
653+
// CHECK: %[[VAL_6:.*]] = arith.constant 8 : index
654+
// CHECK: %[[VAL_7:.*]] = arith.constant 8 : index
655+
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
656+
// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
657+
// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
658+
// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
659+
// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
660+
// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
661+
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
662+
// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
663+
664+
// Transform script: match the linalg.mmt4d payload above and vectorize it
// directly, with no prior tiling.
module attributes {transform.with_named_sequence} {
665+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
666+
%mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
667+
transform.structured.vectorize %mmt4d : !transform.any_op
668+
transform.yield
669+
}
670+
}
671+
672+
// -----
673+
642674
func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
643675
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
644676
outs(%C: memref<?x?xf32>)

0 commit comments

Comments
 (0)