
Commit d775398

[mlir][linalg] Add e2e test for linalg.mmt4d + pack/unpack (#84964)

This is a follow-up for #81790. This patch extends
test/Integration/Dialect/Linalg/CPU/mmt4d.mlir with pack/unpack ops so that
the overall computation is a matrix multiplication (as opposed to
linalg.mmt4d). For comparison (and to make it easier to verify correctness),
linalg.matmul is also included in the test.

1 parent d7975c9 commit d775398

1 file changed: 173 additions, 0 deletions

// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE:   -one-shot-bufferize="bufferize-function-boundaries" \
// DEFINE:   -buffer-deallocation-pipeline="private-function-dynamic-ownership" \
// DEFINE:   -cse -canonicalize -test-lower-to-llvm
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-cpu-runner -e %{entry_point} -entry-point-result=void \
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s
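
// Under lit, %{compile}, %{entry_point} and %{run} above are textual
// substitutions, so the RUN line expands (roughly; a sketch, with %s standing
// for this file) to:
//
//   mlir-opt <this file> -transform-interpreter ... -test-lower-to-llvm \
//     | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=... \
//     | FileCheck <this file>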

/// End-to-end test for computing matrix-multiplication using linalg.mmt4d. In
/// particular, demonstrates how the following MLIR sequence (implemented in
/// @mmt4d):
///
///   A_pack = tensor.pack A
///   B_pack = tensor.pack B
///   C_pack = tensor.pack C
///   out_pack = linalg.mmt4d(A_pack, B_pack, C_pack)
///   out = tensor.unpack out_pack
///
/// is equivalent to:
///
///   out = linalg.matmul(A, B, C)
///
/// (implemented in @matmul).
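
/// A worked example of the packed shapes used in @mmt4d below (a sketch
/// following tensor.pack semantics: each outer dim is the ceil-division of the
/// source dim by its inner tile, with padding_value filling the remainder):
///
///   A: 7x16,  inner_tiles = [8, 1] -> outer [ceil(7/8), ceil(16/1)] = [1, 16], i.e. 1x16x8x1
///   B: 16x13, inner_tiles = [8, 1] on dims [1, 0], outer_dims_perm = [1, 0] -> 2x16x8x1
///   C: 7x13,  inner_tiles = [8, 8] -> outer [ceil(7/8), ceil(13/8)] = [1, 2], i.e. 1x2x8x8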

func.func @main() {
  // Allocate and initialise the inputs
  %A_alloc = tensor.empty() : tensor<7x16xi32>
  %B_alloc = tensor.empty() : tensor<16x13xi32>

  %three = arith.constant 3 : i32
  %four = arith.constant 4 : i32
  %A = linalg.fill ins(%three : i32) outs(%A_alloc : tensor<7x16xi32>) -> tensor<7x16xi32>
  %B = linalg.fill ins(%four : i32) outs(%B_alloc : tensor<16x13xi32>) -> tensor<16x13xi32>
  %C = arith.constant dense<[
    [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85],
    [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86],
    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87],
    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88],
    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89],
    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90],
    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91]
  ]> : tensor<7x13xi32>
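
  // Expected values: A is filled with 3 and B with 4 over a shared dimension
  // of 16, so every result entry is C[i][j] + 3 * 4 * 16 = C[i][j] + 192
  // (e.g. the top-left entry: 1 + 192 = 193).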

  // Matrix multiplication via linalg.mmt4d
  // CHECK: Unranked Memref
  // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
  // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
  // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
  // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
  // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
  // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
  // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
  %C_mmt4d = func.call @mmt4d(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
  %xf = tensor.cast %C_mmt4d : tensor<7x13xi32> to tensor<*xi32>
  call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()

  // Matrix multiplication with linalg.matmul
  // CHECK: Unranked Memref
  // CHECK: [193, 200, 207, 214, 221, 228, 235, 242, 249, 256, 263, 270, 277]
  // CHECK: [194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278]
  // CHECK: [195, 202, 209, 216, 223, 230, 237, 244, 251, 258, 265, 272, 279]
  // CHECK: [196, 203, 210, 217, 224, 231, 238, 245, 252, 259, 266, 273, 280]
  // CHECK: [197, 204, 211, 218, 225, 232, 239, 246, 253, 260, 267, 274, 281]
  // CHECK: [198, 205, 212, 219, 226, 233, 240, 247, 254, 261, 268, 275, 282]
  // CHECK: [199, 206, 213, 220, 227, 234, 241, 248, 255, 262, 269, 276, 283]
  %C_matmul = func.call @matmul(%A, %B, %C) : (tensor<7x16xi32>, tensor<16x13xi32>, tensor<7x13xi32>) -> tensor<7x13xi32>
  %xf_2 = tensor.cast %C_matmul : tensor<7x13xi32> to tensor<*xi32>
  call @printMemrefI32(%xf_2) : (tensor<*xi32>) -> ()

  return
}

func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
  %C_matmul = linalg.matmul ins(%A, %B: tensor<7x16xi32>, tensor<16x13xi32>)
                            outs(%C: tensor<7x13xi32>) -> tensor<7x13xi32>

  return %C_matmul : tensor<7x13xi32>
}
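
// For reference, linalg.mmt4d accumulates
//   C[m0][n0][m][n] += A[m0][k0][m][k] * B[n0][k0][n][k],
// i.e. a matmul with the RHS tile transposed. A sketch of the equivalent
// linalg.generic (names and maps below are illustrative, not part of the
// test); the iterator order matches the 6-D contraction mentioned in the
// schedule below:
//
//   #mapA = affine_map<(m0, n0, k0, m, n, k) -> (m0, k0, m, k)>
//   #mapB = affine_map<(m0, n0, k0, m, n, k) -> (n0, k0, n, k)>
//   #mapC = affine_map<(m0, n0, k0, m, n, k) -> (m0, n0, m, n)>
//   linalg.generic {indexing_maps = [#mapA, #mapB, #mapC],
//                   iterator_types = ["parallel", "parallel", "reduction",
//                                     "parallel", "parallel", "reduction"]}
//       ins(%A_pack, %B_pack : ...) outs(%C_pack : ...) {
//     ^bb0(%a: i32, %b: i32, %acc: i32):
//       %prod = arith.muli %a, %b : i32
//       %sum = arith.addi %acc, %prod : i32
//       linalg.yield %sum : i32
//   }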

func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
  %zero = arith.constant 0 : i32

  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
  %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
  %C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32>

  // Pack matrices
  %A_pack = tensor.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
  %B_pack = tensor.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
  %C_pack = tensor.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32>

  // MMT4D
  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32>

  // Unpack output
  %C_out_empty = tensor.empty() : tensor<7x13xi32>
  %C_out_unpack = tensor.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32>

  return %C_out_unpack : tensor<7x13xi32>
}

module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
    %mmt4d = transform.collect_matching @match_mmt4d in %module : (!transform.any_op) -> (!transform.any_op)
    %func = transform.get_parent_op %mmt4d {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">

    // Step 1: Tile
    // Tile parallel dims
    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
    // Tile reduction dims
    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
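
    // After both tiling steps the contraction sits inside six scf.for loops,
    // roughly (a sketch; loop order and induction-variable names are
    // illustrative, and the m/n loops have a single iteration since the tile
    // size equals the dim size):
    //
    //   scf.for %m0 = 0 to 1 step 1            // parallel, tile size 1
    //     scf.for %n0 = 0 to 2 step 1          // parallel, tile size 1
    //       scf.for %m = 0 to 8 step 8         // parallel, tile size 8
    //         scf.for %n = 0 to 8 step 8       // parallel, tile size 8
    //           scf.for %k0 = 0 to 16 step 1   // reduction, tile size 1
    //             scf.for %k = 0 to 1 step 1   // reduction, tile size 1
    //               linalg.mmt4d on 1x1x8x1 / 1x1x8x1 / 1x1x8x8 slices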

    // Step 2: Vectorize
    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op

    // Step 3: Simplify
    // vector.multi_reduction --> vector.contract
    // Generates a 6-dim vector.contract with the dims matching the original
    // MMT4D Op and with the following split into parallel and reduction dims:
    //    * parallel, parallel, reduction, parallel, parallel, reduction
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.reduction_to_contract
      // Reduce the rank of xfer ops. This transforms vector.contract to be
      // more matmul-like and to enable the lowering to outer product Ops.
      transform.apply_patterns.vector.transfer_permutation_patterns
    } : !transform.op<"func.func">

    // Hoisting and LICM - not strictly required
    %func_h = transform.structured.hoist_redundant_vector_transfers %func
      : (!transform.op<"func.func">) -> !transform.op<"func.func">
    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
      : (!transform.op<"func.func">) -> !transform.any_op
    transform.apply_licm to %all_loops : !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op

    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
    // vector.contract with the following split into parallel and reduction
    // dims:
    //    * parallel, parallel, reduction
    transform.apply_patterns to %func_h {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    // Step 4: Lower tensor.pack
    %pack = transform.structured.match ops{["tensor.pack"]} in %func_h
      : (!transform.op<"func.func">) -> !transform.op<"tensor.pack">
    transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
      -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
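
    // As its result types indicate, lower_pack decomposes each tensor.pack
    // into pad + expand_shape + transpose. For %A_pack this looks roughly
    // like (a sketch; SSA names and the exact reassociation are illustrative):
    //
    //   %padded = tensor.pad %A low[0, 0] high[1, 0] { ... tensor.yield %zero ... }
    //       : tensor<7x16xi32> to tensor<8x16xi32>
    //   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
    //       : tensor<8x16xi32> into tensor<1x8x16x1xi32>
    //   %A_pack = linalg.transpose ins(%expanded : tensor<1x8x16x1xi32>)
    //       outs(%A_pack_empty : tensor<1x16x8x1xi32>) permutation = [0, 2, 1, 3]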

    // Step 5: Lower tensor.unpack
    %unpack = transform.structured.match ops{["tensor.unpack"]} in %func_h
      : (!transform.op<"func.func">) -> !transform.op<"tensor.unpack">
    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
      -> (!transform.op<"tensor.empty">,
          !transform.op<"linalg.transpose">,
          !transform.op<"tensor.collapse_shape">,
          !transform.op<"tensor.extract_slice">)
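
    // lower_unpack is the inverse decomposition: transpose + collapse_shape +
    // extract_slice. For the output this looks roughly like (a sketch; SSA
    // names are illustrative):
    //
    //   %t = linalg.transpose ins(%mmt4d : tensor<1x2x8x8xi32>) ...
    //       permutation = [0, 2, 1, 3]                  // -> tensor<1x8x2x8xi32>
    //   %collapsed = tensor.collapse_shape %t [[0, 1], [2, 3]]
    //       : tensor<1x8x2x8xi32> into tensor<8x16xi32>
    //   %C_out = tensor.extract_slice %collapsed[0, 0] [7, 13] [1, 1]
    //       : tensor<8x16xi32> to tensor<7x13xi32>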

    transform.yield
  }

  transform.named_sequence @match_mmt4d(
      %entry: !transform.any_op {transform.readonly}) -> !transform.any_op {
    transform.match.operation_name %entry ["linalg.mmt4d"] : !transform.any_op
    transform.yield %entry : !transform.any_op
  }
}

func.func private @printMemrefI32(%ptr : tensor<*xi32>)
