Skip to content

Commit bc80240

Browse files
authored
[mlir][sve][nfc] Merge the integration tests for linalg.matmul (#74059)
At the moment the logic to tile and vectorize `linalg.matmul` is duplicated in multiple test files: * matmul.mlir * matmul_mixed_ty.mlir Instead, this patch uses `transform.foreach` to apply the same sequence to multiple functions within the same test file (e.g. `matmul_f32` and `matmul_mixed_ty` as defined in the original files). This allows us to merge relevant test files.
1 parent ea4eb69 commit bc80240

File tree

2 files changed

+69
-96
lines changed

2 files changed

+69
-96
lines changed

mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir

Lines changed: 69 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
// RUN: %{compile}
1010

11-
// RUN: %{run} | FileCheck %s
11+
// RUN: %{run} | FileCheck %s --check-prefix=F32
12+
13+
// REDEFINE: %{entry_point} = matmul_mixed_ty
14+
// RUN: %{run} | FileCheck %s --check-prefix=MIXED
1215

1316
func.func @matmul_f32() {
1417
// Matrix dimensions
@@ -32,37 +35,75 @@ func.func @matmul_f32() {
3235
%C_out = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>) outs(%C_in: tensor<?x?xf32>) -> tensor<?x?xf32>
3336

3437
// Print and verify the output
35-
// CHECK-LABEL: SVE: START OF TEST OUTPUT
38+
// F32-LABEL: SVE: START OF TEST OUTPUT
3639
vector.print str "SVE: START OF TEST OUTPUT"
3740

38-
// CHECK-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
39-
// CHECK-COUNT-5: [29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788]
41+
// F32-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
42+
// F32-COUNT-5: [29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788]
4043
%xf = tensor.cast %C_out : tensor<?x?xf32> to tensor<*xf32>
4144
call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
4245

43-
// CHECK-NEXT: SVE: END OF TEST OUTPUT
46+
// F32-NEXT: SVE: END OF TEST OUTPUT
47+
vector.print str "SVE: END OF TEST OUTPUT"
48+
49+
return
50+
}
51+
52+
func.func @matmul_mixed_ty() {
53+
// Matrix dimensions
54+
%K = arith.constant 3 : index
55+
%M = arith.constant 5 : index
56+
%N = arith.constant 15 : index
57+
%c0_i8 = arith.constant 0 : i8
58+
%c0_i32 = arith.constant 0 : i32
59+
60+
// Allocate the matrices
61+
%A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xi8>
62+
%B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xi8>
63+
%C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>
64+
65+
// Initialise the matrices
66+
%pi = arith.constant 123 : i8
67+
%A = linalg.fill ins(%pi : i8) outs(%A_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
68+
%B = linalg.fill ins(%pi : i8) outs(%B_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
69+
%C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>
70+
71+
// Matmul
72+
%C_out = linalg.matmul ins(%A, %B: tensor<?x?xi8>, tensor<?x?xi8>) outs(%C_in: tensor<?x?xi32>) -> tensor<?x?xi32>
73+
74+
// Print and verify the output
75+
// MIXED-LABEL: SVE: START OF TEST OUTPUT
76+
vector.print str "SVE: START OF TEST OUTPUT"
77+
78+
// MIXED-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
79+
// MIXED-COUNT-5: [45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387]
80+
%xf = tensor.cast %C_out : tensor<?x?xi32> to tensor<*xi32>
81+
call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
82+
83+
// MIXED-NEXT: SVE: END OF TEST OUTPUT
4484
vector.print str "SVE: END OF TEST OUTPUT"
4585

4686
return
4787
}
4888

4989
module attributes {transform.with_named_sequence} {
50-
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
51-
%matmul = transform.structured.match ops{["linalg.matmul"]} in %module
52-
: (!transform.any_op) -> !transform.any_op
90+
// A sequence that will tile and vectorise a Matmul Op
91+
transform.named_sequence @tile_and_vectorize_matmul(%func
92+
: !transform.op<"func.func"> {transform.readonly}) {
93+
94+
// Step 0: Get a handle to the matmul Op
95+
%matmul = transform.structured.match ops{["linalg.matmul"]} in %func
96+
: (!transform.op<"func.func">) -> !transform.any_op
5397

5498
// Step 1: Tile
55-
%module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
99+
%tiled_matmul, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
56100
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
101+
transform.print %tiled_matmul {name = "matmul lal"}: !transform.any_op
57102

58103
// Step 2: Vectorize
59-
%tiled_matmul = transform.structured.match ops{["linalg.matmul"]} in %module_with_tiled_loops
60-
: (!transform.any_op) -> !transform.any_op
61104
transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op
62105

63106
// Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
64-
%func = transform.structured.match ops{["func.func"]} in %module
65-
: (!transform.any_op) -> !transform.op<"func.func">
66107
transform.apply_patterns to %func {
67108
transform.apply_patterns.vector.reduction_to_contract
68109
transform.apply_patterns.vector.transfer_permutation_patterns
@@ -77,6 +118,21 @@ transform.named_sequence @__transform_main(%module: !transform.any_op {transform
77118

78119
transform.yield
79120
}
121+
122+
// A sequence that goes over all functions in tis module and applies
123+
// "tile_and_vectorize_matmul"
124+
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
125+
%funcs = transform.structured.match ops{["func.func"]} in %module
126+
: (!transform.any_op) -> !transform.op<"func.func">
127+
128+
transform.foreach %funcs : !transform.op<"func.func"> {
129+
^bb2(%func : !transform.op<"func.func">):
130+
transform.include @tile_and_vectorize_matmul failures(propagate)
131+
(%func) : (!transform.op<"func.func">) -> ()
132+
}
133+
transform.yield
134+
}
80135
}
81136

82137
func.func private @printMemrefF32(%ptr : tensor<*xf32>)
138+
func.func private @printMemrefI32(%ptr : tensor<*xi32>)

mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul_mixed_ty.mlir

Lines changed: 0 additions & 83 deletions
This file was deleted.

0 commit comments

Comments
 (0)