Commit 9f1c8b1

[mlir][tensor][SVE] Add e2e test for tensor.pack targeting SVE (llvm#119692)
1 parent b558c6b commit 9f1c8b1

File tree

1 file changed: +181 −0 lines changed
@@ -0,0 +1,181 @@
// REQUIRES: arm-emulator

// This test is a clone of pack-dynamic-inner-tile.mlir, but the inner tile is
// vector.vscale * %c8 rather than %c8. In order to demonstrate the impact of
// using scalable vectors, vscale is set to 2 so that the run-time tile
// size is [16, 1] rather than [8, 1].
//
// Note that you can also tweak the size of vscale by passing this flag to
// QEMU:
//  * -cpu max,sve-max-vq=[1-16]
// (select a value between 1 and 16).
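//
// For example, with "-cpu max,sve-max-vq=2" the vector length is 2 * 128 =
// 256 bits, i.e. vscale = 2, so the inner tile holds 8 * vscale = 16
// elements.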

// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   --transform-interpreter --test-transform-dialect-erase-schedule \
// DEFINE:   --lower-vector-mask \
// DEFINE:   -canonicalize -cse --convert-vector-to-scf \
// DEFINE:   -arm-sve-legalize-vector-storage -convert-vector-to-llvm="enable-arm-sve" -test-lower-to-llvm -o %t

// DEFINE: %{entry_point} = main
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve"\
// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils

// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s

/// End-to-end test for tensor.pack where one of the inner tile sizes is
/// scalable.

func.func @main() {
  // Allocate and initialise the inputs
  %A_alloc = tensor.empty() : tensor<7x16xi32>

  %A = arith.constant dense<[
    [ 1,  8, 15, 22, 29, 36, 43, 50, 57, 64, 71, 78, 85, 92,  99, 106],
    [ 2,  9, 16, 23, 30, 37, 44, 51, 58, 65, 72, 79, 86, 93, 100, 107],
    [ 3, 10, 17, 24, 31, 38, 45, 52, 59, 66, 73, 80, 87, 94, 101, 108],
    [ 4, 11, 18, 25, 32, 39, 46, 53, 60, 67, 74, 81, 88, 95, 102, 109],
    [ 5, 12, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89, 96, 103, 110],
    [ 6, 13, 20, 27, 34, 41, 48, 55, 62, 69, 76, 83, 90, 97, 104, 111],
    [ 7, 14, 21, 28, 35, 42, 49, 56, 63, 70, 77, 84, 91, 98, 105, 112]
  ]> : tensor<7x16xi32>

  func.call @pack(%A) : (tensor<7x16xi32>) -> ()

  return
}

func.func private @pack(%A: tensor<7x16xi32>) {
  %c1 = arith.constant 1 : index
  %pad_val = arith.constant 123 : i32

  // Set vscale to 2 (vector width = 256 bits). This has the same effect as:
  //  * qemu-aarch64 -cpu max,sve-max-vq=2 (...)
  %c256 = arith.constant 256 : i32
  func.call @setArmVLBits(%c256) : (i32) -> ()
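  // (An SVE vector granule is 128 bits, so a 256-bit vector length
  // corresponds to vscale = 256 / 128 = 2.)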

  // Scalable tile size
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %c8, %vs : index
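  // With vscale = 2 (set above), %tile_size = 8 * vscale = 16.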

  %A_pack_empty = tensor.empty(%c1, %tile_size) : tensor<?x16x?x1xi32>

  %A_pack = tensor.pack %A
    padding_value(%pad_val : i32)
    inner_dims_pos = [0, 1]
    inner_tiles = [%tile_size, 1]
    into %A_pack_empty : tensor<7x16xi32> -> tensor<?x16x?x1xi32>
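  // At runtime the result is tensor<1x16x16x1xi32>: the outer dim is
  // ceil(7 / 16) = 1 and the packed inner dim equals %tile_size = 16. Only 7
  // of the 16 rows in each inner tile come from %A; the remaining 9 are
  // filled with %pad_val (123).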

  %A_cast = tensor.cast %A_pack : tensor<?x16x?x1xi32> to tensor<*xi32>

  // Print the results
  // CHECK: Unranked Memref base@ = 0{{.*}} rank = 4 offset = 0 sizes = [1, 16, 16, 1] strides = [256, 16, 1, 1] data =
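  // (For a row-major 1x16x16x1 memref the strides are
  // [16*16*1, 16*1, 1, 1] = [256, 16, 1, 1], matching the line above.)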
  // Tile 1: ((vscale x 8) x 1)
  // CHECK-NEXT: 1
  // CHECK-NEXT: 2
  // CHECK-NEXT: 3
  // CHECK-NEXT: 4
  // CHECK-NEXT: 5
  // CHECK-NEXT: 6
  // CHECK-NEXT: 7
  // Expect the pad value after the first 7 elements
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // Tile 2: ((vscale x 8) x 1)
  // CHECK-NEXT: 8
  // CHECK-NEXT: 9
  // CHECK-NEXT: 10
  // CHECK-NEXT: 11
  // CHECK-NEXT: 12
  // CHECK-NEXT: 13
  // CHECK-NEXT: 14
  // Expect the pad value after a further 7 elements
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // CHECK-NEXT: 123
  // Tile 3: ((vscale x 8) x 1)
  // CHECK-NEXT: 15
  // CHECK-NEXT: 16
  // ...
  call @printMemrefI32(%A_cast) : (tensor<*xi32>) -> ()

  return
}

module @transforms attributes { transform.with_named_sequence } {
  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.consume}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op

    // 1. Tile so that we can decompose tensor.pack into tensor.pad and other
    // Ops (see step 2)
    %tiled_pack_op_p, %loops:2 = transform.structured.tile_using_for %pack tile_sizes [1, 1]
       : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
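    // (Tiling both outer dims by 1 produces two scf.for loops, hence the
    // two-element %loops:2 handle.)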

    // 2. Decompose the tiled pack Op into (trimmed for brevity):
    //
    //  %padded = tensor.pad %slice_of_A (..) :
    //    tensor<?x?xi32> to tensor<8x1xi32>
    //  %inserted_slice = tensor.insert_slice %padded into %slice_of_A_pack (...) :
    //    tensor<8x1xi32> into tensor<1x1x?x1xi32>
    //
    // (NOTE: no tile is transposed, hence no linalg.transpose)
    //
    // This is followed by this decomposition of the pad Op:
    //
    //  %c123_i32 = arith.constant 123 : i32
    //  %slice_of_A = tensor.extract_slice %A[%3, %arg3] [%4, %5] [1, 1] :
    //    tensor<7x16xi32> to tensor<?x?xi32>
    //  %empty = tensor.empty() : tensor<8x1xi32>
    //  %fill = linalg.fill ins(%c123_i32 : i32) outs(%empty :
    //    tensor<8x1xi32>) -> tensor<8x1xi32>
    //  %inserted_slice = tensor.insert_slice %slice_of_A into %fill[0, 0] [%4, %5] [1, 1] :
    //    tensor<?x?xi32> into tensor<8x1xi32>
    //
    %func_op = transform.get_parent_op %tiled_pack_op_p {isolated_from_above} : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.linalg.decompose_pack_unpack
      transform.apply_patterns.linalg.decompose_pad
    } : !transform.op<"func.func">

    // 3. Vectorize linalg.fill.
    // Vector sizes match the inner tiles in the payload IR.
    %fill = transform.structured.match ops{["linalg.fill"]} in %func_op : (!transform.op<"func.func">) -> !transform.any_op
    transform.structured.vectorize %fill vector_sizes [[8], 1] : !transform.any_op
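    // (The extra brackets in [[8], 1] mark the leading size as scalable, i.e.
    // the fill is vectorized to vector<[8]x1xi32>, which holds 8 * vscale
    // elements, matching the scalable inner tile.)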

    transform.apply_patterns to %func_op {
      transform.apply_patterns.tensor.fold_tensor_subset_ops
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    // 4. Bufferize before lowering to LLVM
    %bufferize = transform.bufferization.one_shot_bufferize %module
      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op

    // 5. Canonicalize + rank-reducing patterns (to get rid of the trailing
    // unit dim).
    %func_op_bufferized = transform.structured.match ops{["func.func"]} in %bufferize : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op_bufferized {
      transform.apply_patterns.vector.rank_reducing_subview_patterns
      transform.apply_patterns.vector.drop_unit_dims_with_shape_cast
      transform.apply_patterns.canonicalization
    } : !transform.op<"func.func">

    transform.yield
  }
}

func.func private @printMemrefI32(%ptr : tensor<*xi32>)
func.func private @setArmVLBits(%bits : i32)
