Skip to content

Commit b1900aa

Browse files
committed
separate linearization tests
1 parent c024969 commit b1900aa

File tree

2 files changed

+407
-239
lines changed

2 files changed

+407
-239
lines changed
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
2+
// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
3+
4+
// ALL-LABEL: test_linearize
5+
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
6+
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
7+
8+
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
9+
// BW-128: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
10+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
11+
12+
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
13+
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
14+
15+
// BW-128: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
16+
// BW-0: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32>
17+
%1 = math.sin %arg0 : vector<2x2xf32>
18+
19+
// BW-128: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32>
20+
// BW-0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
21+
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
22+
23+
// ALL: return %[[RES]] : vector<2x2xf32>
24+
return %0 : vector<2x2xf32>
25+
}
26+
27+
// -----
28+
29+
// ALL-LABEL: test_linearize_poison
30+
func.func @test_linearize_poison() -> vector<2x2xf32> {
31+
32+
// BW-128: %[[POISON:.*]] = ub.poison : vector<4xf32>
33+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
34+
35+
// BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
36+
%0 = ub.poison : vector<2x2xf32>
37+
// ALL: return %[[RES]] : vector<2x2xf32>
38+
return %0 : vector<2x2xf32>
39+
}
40+
41+
// -----
42+
43+
// ALL-LABEL: test_partial_linearize
44+
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
45+
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
46+
47+
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
48+
// BW-128: %[[CST:.*]] = arith.constant dense<{{.*}}> : vector<4xf32>
49+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
50+
51+
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
52+
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
53+
54+
// BW-128: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
55+
// BW-0: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
56+
%5 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0,3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 5.0, 6.0]]> : vector<4x4xf32>
57+
58+
// Arith and math ops are handled in generic way, check some of them
59+
// BW-128: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32>
60+
// BW-0: %[[SIN:.*]] = math.sin %[[ORIG_ARG]] : vector<2x2xf32>
61+
%1 = math.sin %arg0 : vector<2x2xf32>
62+
63+
// BW-128: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
64+
// BW-0: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
65+
%6 = math.sin %arg1 : vector<4x4xf32>
66+
67+
// BW-128: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32>
68+
// BW-0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
69+
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
70+
71+
// BW-128: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
72+
// BW-0: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
73+
%7 = arith.addf %arg1, %5 : vector<4x4xf32>
74+
75+
// ALL: return %[[RES]] : vector<2x2xf32>
76+
return %0 : vector<2x2xf32>
77+
}
78+
79+
// -----
80+
81+
// ALL-LABEL: test_index_no_linearize
82+
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
83+
// BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
84+
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
85+
return %0 : vector<2x2xindex>
86+
}
87+
88+
// -----
89+
90+
// ALL-LABEL: func.func @test_scalable_linearize(
91+
// ALL-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
92+
func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
93+
// BW-128: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
94+
// BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
95+
// BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[2]xf32>
96+
%0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32>
97+
98+
// BW-128: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
99+
// BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[2]xf32>
100+
%1 = math.sin %arg0 : vector<2x[2]xf32>
101+
102+
// BW-128: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
103+
// BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[2]xf32>
104+
%2 = arith.addf %0, %1 : vector<2x[2]xf32>
105+
106+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
107+
// ALL: return %[[RES]] : vector<2x[2]xf32>
108+
return %2 : vector<2x[2]xf32>
109+
}
110+
111+
112+
// -----
113+
114+
// ALL-LABEL: test_extract_strided_slice_1
115+
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
116+
func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
117+
118+
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
119+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
120+
// BW-128-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
121+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
122+
// BW-128: return %[[RES]] : vector<2x2xf32>
123+
124+
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
125+
// BW-0: return %[[RES]] : vector<2x2xf32>
126+
%0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
127+
: vector<4x8xf32> to vector<2x2xf32>
128+
return %0 : vector<2x2xf32>
129+
}
130+
131+
132+
// -----
133+
134+
// ALL-LABEL: test_extract_strided_slice_2
135+
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
136+
func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
137+
138+
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
139+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
140+
// BW-128-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
141+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
142+
// BW-128: return %[[RES]] : vector<1x4x2xf32>
143+
144+
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
145+
// BW-0: return %[[RES]] : vector<1x4x2xf32>
146+
%0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
147+
: vector<2x8x2xf32> to vector<1x4x2xf32>
148+
return %0 : vector<1x4x2xf32>
149+
}
150+
151+
// -----
152+
153+
// ALL-LABEL: test_vector_shuffle
154+
// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
155+
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
156+
157+
// BW-128-DAG: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
158+
// BW-128-DAG: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
159+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
160+
// BW-128-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
161+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
162+
// BW-128: return %[[RES]] : vector<8x2xf32>
163+
164+
// BW-0: %[[RES:.*]] = vector.shuffle %[[ORIG_ARG0]], %[[ORIG_ARG1]] [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
165+
// BW-0: return %[[RES]] : vector<8x2xf32>
166+
%0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
167+
return %0 : vector<8x2xf32>
168+
}
169+
170+
// -----
171+
172+
// ALL-LABEL: test_vector_extract
173+
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
174+
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
175+
176+
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
177+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
178+
// BW-128-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
179+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
180+
// BW-128: return %[[RES]] : vector<8x2xf32>
181+
182+
// BW-0: %[[RES:.*]] = vector.extract %[[ORIG_ARG]][1] : vector<8x2xf32> from vector<2x8x2xf32>
183+
// BW-0: return %[[RES]] : vector<8x2xf32>
184+
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
185+
return %0 : vector<8x2xf32>
186+
}
187+
188+
// -----
189+
190+
// ALL-LABEL: test_vector_insert
191+
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
192+
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
193+
194+
// BW-128-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
195+
// BW-128-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
196+
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
197+
// BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
198+
// BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
199+
// BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
200+
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
201+
// BW-128: return %[[RES]] : vector<2x8x4xf32>
202+
203+
// BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
204+
// BW-0: return %[[RES]] : vector<2x8x4xf32>
205+
206+
%0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
207+
return %0 : vector<2x8x4xf32>
208+
}
209+
210+
// -----
211+
212+
// ALL-LABEL: test_vector_bitcast
213+
// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
214+
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
215+
216+
// BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
217+
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
218+
%1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
219+
return %1 : vector<4x8xf16>
220+
}
221+
222+
// -----
223+
224+
// ALL-LABEL: test_vector_bitcast
225+
// ALL-SAME: %[[ARG_0:.*]]: vector<4x2xf32>
226+
func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
227+
// BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
228+
// BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
229+
// BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
230+
231+
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x2xf32> to vector<4x4xf16>
232+
%1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
233+
return %1 : vector<4x4xf16>
234+
}
235+
236+
// -----
237+
238+
// ALL-LABEL: test_vector_bitcast
239+
// ALL-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32>
240+
func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
241+
// BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
242+
// BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
243+
// BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
244+
245+
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x[2]xf32> to vector<4x[4]xf16>
246+
%1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
247+
return %1 : vector<4x[4]xf16>
248+
}
249+
250+
// -----
251+
252+
// ALL-LABEL: test_vector_bitcast
253+
// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
254+
func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
255+
// BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
256+
// BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
257+
// BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
258+
259+
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<[4]x2xf32> to vector<[4]x4xf16>
260+
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
261+
return %1 : vector<[4]x4xf16>
262+
}

0 commit comments

Comments
 (0)