
Commit 9f0a912

[mlir][test][sve] Add e2e test for linalg.pack + linalg.unpack (#129696)
This patch adds an e2e test for the `linalg.pack` + `linalg.unpack` pair with a dynamic inner tile size that is tied to SVE's "vscale":

```mlir
%c4 = arith.constant 4 : index
%vs = vector.vscale
%tile_size = arith.muli %c4, %vs : index
```

This means that the actual sizes of the corresponding inner and outer tiles depend on the runtime value of "vscale".

To make the new test deterministic (and to make it easier to experiment), I have hard-coded the value of "vscale" to 2 via (2 x 128 bits = 256 bits):

```mlir
func.call @setArmVLBits(%c256) : (i32) -> ()
```

This can be relaxed at a later time or played with when experimenting locally with e.g. QEMU.

NOTE: Vectorization has not been enabled yet (scalable vectorization of `linalg.unpack` is still WIP).
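For reference, a minimal sketch (not part of the patch) of the shapes involved when vscale is pinned to 2 as above; `%A`, `%pad` and `%dest` are placeholder names:

```mlir
// With vscale = 2, %tile_size = 4 * 2 = 8.
// Outer dims = [ceildiv(7, 4), ceildiv(12, 8)] = [2, 2], so packing the 7x12 input
// yields tensor<2x2x4x8xi32>, with one padded row and four padded columns holding %pad.
%packed = linalg.pack %A padding_value(%pad : i32)
    inner_dims_pos = [0, 1] inner_tiles = [4, 8]
    into %dest : tensor<7x12xi32> -> tensor<2x2x4x8xi32>
```

The subsequent unpack reads the same tiles back and drops the padded row/columns, recovering the original 7x12 matrix.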
1 parent 878a64f commit 9f0a912

File tree

1 file changed: 185 additions, 0 deletions
@@ -0,0 +1,185 @@
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -transform-interpreter -test-transform-dialect-erase-schedule \
// DEFINE: --lower-vector-mask |\
// DEFINE: mlir-opt -arm-sve-legalize-vector-storage -convert-vector-to-llvm="enable-arm-sve"\
// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = mlir-cpu-runner %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils

// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
/// End-to-end test for linalg.pack + linalg.unpack where one of the inner tile sizes is
/// scalable.
/// NOTE: Vectorization has not been enabled yet!


/// The main entry point
func.func @main() {
  // Set vscale to 2 (vector width = 256). This has the same effect as:
  //  * qemu-aarch64 -cpu max,sve-max-vq=2 (...)
  // (If your platform supports it, you can play with other values as well.)
  %c256 = arith.constant 256 : i32
  func.call @setArmVLBits(%c256) : (i32) -> ()

  // Dynamic/scalable tile size (vscale x 4)
  %c4 = arith.constant 4 : index
  %vs = vector.vscale
  %tile_size = arith.muli %c4, %vs : index

  vector.print str "\nINNER TILE SIZE (run-time value): "
  vector.print %tile_size : index

  // Input matrix. The values and dimensions have been selected so that this
  // matrix can be viewed as:
  //    +--------+--------+--------+
  //    |        |        |        |
  //    |  4x4   |  4x4   |  4x4   |
  //    |        |        |        |
  //    +--------+--------+--------+
  //    |        |        |        |
  //    |  3x4   |  3x4   |  3x4   |
  //    |        |        |        |
  //    +--------+--------+--------+
  // This way, after packing, there will be "incomplete" tiles that will
  // contain the padding value. After unpacking, the padding value should be
  // gone.
  %A_before = arith.constant dense<[
    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
    [4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6],
    [4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6],
    [4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]
  ]> : tensor<7x12xi32>

  // STEP 1: PACK + UNPACK
  // TODO: We should change the order to: Pack+print, Unpack+print. However, that causes the
  // bufferization to fail with:
  //  * 'tensor.cast' op not bufferizable under the given constraints: cannot avoid RaW conflict
  // Investigate and either fix or remove this comment (if it is impossible to work around).
  %A_pack = func.call @pack_main(%A_before, %tile_size) : (tensor<7x12xi32>, index) -> tensor<2x?x4x?xi32>
  %A_unpack = func.call @unpack_main(%A_pack, %tile_size) : (tensor<2x?x4x?xi32>, index) -> tensor<7x12xi32>

  // STEP 2: Print the matrices
  vector.print str "\nINPUT MATRIX (before packing)\n"
  %A_before_cast = tensor.cast %A_before : tensor<7x12xi32> to tensor<*xi32>
  call @printMemrefI32(%A_before_cast) : (tensor<*xi32>) -> ()

  vector.print str "\nINPUT MATRIX (after packing)\n"
  %A_pack_cast = tensor.cast %A_pack : tensor<2x?x4x?xi32> to tensor<*xi32>
  // There ought to be at least one pad value inserted into a tile
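  // (With vscale hard-coded to 2, %tile_size = 8 and the packed tensor is 2x2x4x8; the
  // pad value lands in the last row of the tiles in the second outer row and in the
  // last four columns of the tiles in the second outer column.)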
  // CHECK-LABEL: (after packing)
  // CHECK: 123
  call @printMemrefI32(%A_pack_cast) : (tensor<*xi32>) -> ()

  vector.print str "\nINPUT MATRIX (after unpacking)\n"
  %A_unpack_cast = tensor.cast %A_unpack : tensor<7x12xi32> to tensor<*xi32>
  // This ought to match the input matrix
  // CHECK-LABEL: (after unpacking)
  // CHECK: [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
  // CHECK: [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
  // CHECK: [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
  // CHECK: [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
  // CHECK: [4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6],
  // CHECK: [4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6],
  // CHECK: [4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]
  call @printMemrefI32(%A_unpack_cast) : (tensor<*xi32>) -> ()

  return
}

/// Takes the unpacked matrix and the inner tile size to use, and returns the packed matrix.
func.func private @pack_main(%A: tensor<7x12xi32>, %inner_tile_size: index) -> (tensor<2x?x4x?xi32>) {
  // Get the size of dim 1 (we could skip tensor.dim, but this way we keep it generic).
  %c1 = arith.constant 1 : index
  %dim_1 = tensor.dim %A, %c1 : tensor<7x12xi32>

  // Compute the outer-tile size corresponding to the dynamic inner tile size.
  // NOTE: This step is important. While as a user we would only tweak the
  // inner tile sizes, we need to make sure that the outer sizes are updated
  // accordingly.
  %outer_tile_size = arith.ceildivui %dim_1, %inner_tile_size : index
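  // For example, with vscale pinned to 2 in @main: %inner_tile_size = 4 * 2 = 8, so
  // %outer_tile_size = ceildiv(12, 8) = 2.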

  // NOTE: This is deliberately much larger than the input values in %A_before
  // so that it's easy to spot it in the output.
  %pad_val = arith.constant 123 : i32

  %A_pack_empty = tensor.empty(%outer_tile_size, %inner_tile_size) : tensor<2x?x4x?xi32>

  %A_pack = linalg.pack %A
    padding_value(%pad_val : i32)
    inner_dims_pos = [0, 1]
    inner_tiles = [4, %inner_tile_size]
    into %A_pack_empty : tensor<7x12xi32> -> tensor<2x?x4x?xi32>

  return %A_pack : tensor<2x?x4x?xi32>
}

/// Takes the packed matrix, unpacks it and returns the result.
func.func private @unpack_main(%A_pack : tensor<2x?x4x?xi32>, %inner_tile_size: index) -> tensor<7x12xi32> {
  %A_unpack_empty = tensor.empty() : tensor<7x12xi32>

  %A_unpack = linalg.unpack %A_pack
    inner_dims_pos = [0, 1]
    inner_tiles = [4, %inner_tile_size]
    into %A_unpack_empty : tensor<2x?x4x?xi32> -> tensor<7x12xi32>

  return %A_unpack : tensor<7x12xi32>
}

module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consume}) {
    %pack = transform.structured.match ops{["linalg.pack"]} in %module : (!transform.any_op) -> !transform.any_op
    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module : (!transform.any_op) -> !transform.any_op

    // 1.1 Tile the linalg.pack Op so that we can decompose it into e.g. tensor.pad
    // and other lower-level Ops (see step 2.1)
    %tiled_pack_op_p, %loops_pack:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // 1.2 Tile the linalg.unpack Op so that we can decompose it into e.g. tensor.pad
    // and other lower-level Ops (see step 2.2)
    %tiled_unpack_op_p, %loops_unpack:2 = transform.structured.tile_using_for %unpack tile_sizes [4, 1]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    // 2.1. Decompose tiled PackOp into lower-level Ops
    %func_op_pack = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_pack {
      transform.apply_patterns.linalg.decompose_pack_unpack
      transform.apply_patterns.linalg.decompose_pad
    } : !transform.op<"func.func">

    transform.apply_patterns to %func_op_pack {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    // 2.2. Decompose tiled UnpackOp into lower-level Ops
    %func_op_unpack = transform.get_parent_op %tiled_unpack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_unpack {
      transform.apply_patterns.linalg.decompose_pack_unpack
    } : !transform.op<"func.func">

    transform.apply_patterns to %func_op_unpack {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    // 3. Bufferize before lowering to LLVM
    %bufferize = transform.bufferization.one_shot_bufferize %module
      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op

    // 4. Canonicalize
    %func_op_bufferized = transform.structured.match ops{["func.func"]} in %bufferize : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_bufferized {
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    transform.yield
  }
}

func.func private @printMemrefI32(%ptr : tensor<*xi32>)
func.func private @setArmVLBits(%bits : i32)
