Skip to content

Commit d6850be

Browse files
authored
[mlir][linalg] Add e2e test for linalg.mmt4d (#81790)
Follow-up for #81422. My intention is to write an e2e test targetting SVE, but more work is needed. Sending this as an intermiedate step.
1 parent bb029a5 commit d6850be

File tree

1 file changed

+121
-0
lines changed
  • mlir/test/Integration/Dialect/Linalg/CPU

1 file changed

+121
-0
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// DEFINE: %{compile} = mlir-opt %s \
2+
// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
3+
// DEFINE: -one-shot-bufferize -func-bufferize -cse -canonicalize -convert-vector-to-scf -test-lower-to-llvm -o %t
4+
// DEFINE: %{entry_point} = mmt4d
5+
// DEFINE: %{run} = mlir-cpu-runner %t -e %{entry_point} -entry-point-result=void \
6+
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
7+
8+
// RUN: %{compile}
9+
10+
// RUN: %{run} | FileCheck %s
11+
12+
func.func @mmt4d() {
13+
// Allocate the matrices
14+
%A_alloc = tensor.empty() : tensor<2x2x3x1xi32>
15+
%B_alloc = tensor.empty() : tensor<2x2x3x1xi32>
16+
%C_alloc = tensor.empty() : tensor<2x2x3x3xi32>
17+
%C_in = arith.constant dense<[
18+
[[[ 1, 2, 3],
19+
[ 4, 5, 6],
20+
[ 7, 8, 9]],
21+
[[ 11, 12, 13],
22+
[ 14, 15, 16],
23+
[ 17, 18, 19]]],
24+
[[[ 21, 22, 23],
25+
[ 24, 25, 26],
26+
[ 27, 28, 29]],
27+
[[ 31, 32, 33],
28+
[ 34, 35, 36],
29+
[ 37, 38, 39]]]
30+
]> : tensor<2x2x3x3xi32>
31+
32+
// Initialise the matrices
33+
%three = arith.constant 3 : i32
34+
%four = arith.constant 4 : i32
35+
%A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<2x2x3x1xi32>) -> tensor<2x2x3x1xi32>
36+
%B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<2x2x3x1xi32>) -> tensor<2x2x3x1xi32>
37+
38+
// Matmul
39+
%C_out = linalg.mmt4d ins(%A, %B: tensor<2x2x3x1xi32>, tensor<2x2x3x1xi32>) outs(%C_in: tensor<2x2x3x3xi32>) -> tensor<2x2x3x3xi32>
40+
41+
// Print and verify the output
42+
// CHECK: Unranked Memref {{.*}} rank = 4 offset = 0 sizes = [2, 2, 3, 3] strides = [18, 9, 3, 1] data =
43+
// C[0, 0]
44+
// CHECK-NEXT: [25, 26, 27]
45+
// CHECK-NEXT: [28, 29, 30]
46+
// CHECK-NEXT: [31, 32, 33]
47+
// C[0, 1]
48+
// CHECK-NEXT: [35, 36, 37]
49+
// CHECK-NEXT: [38, 39, 40]
50+
// CHECK-NEXT: [41, 42, 43]
51+
// C[1, 0]
52+
// CHECK-NEXT: [45, 46, 47]
53+
// CHECK-NEXT: [48, 49, 50]
54+
// CHECK-NEXT: [51, 52, 53]
55+
// C[1, 1]
56+
// CHECK-NEXT: [55, 56, 57]
57+
// CHECK-NEXT: [58, 59, 60]
58+
// CHECK-NEXT: [61, 62, 63]
59+
60+
%xf = tensor.cast %C_out : tensor<2x2x3x3xi32> to tensor<*xi32>
61+
call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
62+
63+
return
64+
}
65+
66+
module @transforms attributes { transform.with_named_sequence } {
67+
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
68+
%mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
69+
%func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
70+
71+
// Step 1: Tile
72+
// Tile parallel dims
73+
%tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 3, 3, 0]
74+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
75+
// Tile reduction dims
76+
%tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
77+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
78+
79+
// Step 2: Vectorize
80+
transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
81+
82+
// Step 3: Simplify
83+
// vector.multi_reduction --> vector.contract
84+
// Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
85+
// and with the following split into parallel and reduction dims:
86+
// * parallel, parallel, reduction, parallel, parallel, reduction
87+
transform.apply_patterns to %func {
88+
transform.apply_patterns.vector.reduction_to_contract
89+
// Reduce the rank of xfer ops. This transforms vector.contract to be
90+
// more matmul-like and to enable the lowering to outer product Ops.
91+
transform.apply_patterns.vector.transfer_permutation_patterns
92+
} : !transform.op<"func.func">
93+
94+
// Hoisting and LICM - not strictly required
95+
%func_h = transform.structured.hoist_redundant_vector_transfers %func
96+
: (!transform.op<"func.func">) -> !transform.op<"func.func">
97+
%all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
98+
: (!transform.op<"func.func">) -> !transform.any_op
99+
transform.apply_licm to %all_loops : !transform.any_op
100+
transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
101+
102+
// Simplify the 6-dim vector.contract into a 3-dim matmul-like
103+
// vector.contract with the following split into parallel and reduction
104+
// dims:
105+
// * parallel, parallel, reduction
106+
transform.apply_patterns to %func_h {
107+
transform.apply_patterns.vector.reduction_to_contract
108+
transform.apply_patterns.vector.cast_away_vector_leading_one_dim
109+
transform.apply_patterns.canonicalization
110+
} : !transform.op<"func.func">
111+
transform.yield
112+
}
113+
114+
transform.named_sequence @match_mmt4d(
115+
%entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
116+
transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
117+
transform.yield %entry : !transform.any_op
118+
}
119+
}
120+
121+
func.func private @printMemrefI32(%ptr : tensor<*xi32>)

0 commit comments

Comments
 (0)