Skip to content

Commit 1207e22

Browse files
committed
reduce duplication in testing, test just bitwidth specific logic
1 parent 1550605 commit 1207e22

File tree

1 file changed

+15
-220
lines changed

1 file changed

+15
-220
lines changed
Lines changed: 15 additions & 220 deletions
Original file line numberDiff line numberDiff line change
@@ -1,191 +1,42 @@
11
// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=128 -verify-diagnostics | FileCheck %s --check-prefixes=ALL,BW-128
2-
// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
3-
4-
// ALL-LABEL: test_linearize
5-
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
6-
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
7-
8-
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
9-
// BW-128: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
10-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
11-
12-
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
13-
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
142

15-
// BW-128: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
16-
// BW-0: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32>
17-
%1 = math.sin %arg0 : vector<2x2xf32>
18-
19-
// BW-128: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32>
20-
// BW-0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
21-
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
22-
23-
// ALL: return %[[RES]] : vector<2x2xf32>
24-
return %0 : vector<2x2xf32>
25-
}
26-
27-
// -----
28-
29-
// ALL-LABEL: test_linearize_poison
30-
func.func @test_linearize_poison() -> vector<2x2xf32> {
31-
32-
// BW-128: %[[POISON:.*]] = ub.poison : vector<4xf32>
33-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[POISON]] : vector<4xf32> to vector<2x2xf32>
34-
35-
// BW-0: %[[RES:.*]] = ub.poison : vector<2x2xf32>
36-
%0 = ub.poison : vector<2x2xf32>
37-
// ALL: return %[[RES]] : vector<2x2xf32>
38-
return %0 : vector<2x2xf32>
39-
}
3+
// RUN: mlir-opt %s -split-input-file -test-bit-width-constrained-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefixes=ALL,BW-0
404

41-
// -----
425

43-
// ALL-LABEL: test_partial_linearize
44-
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>)
45-
func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> {
6+
// A vector<2x2xf32> has inner-most dimension with 64-bits. Check that at
7+
// bitwidth threshold 128 (>= 64), operations are linearized, and at
8+
// bitwidth threshold 0 (< 64), operations are linearized.
469

47-
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
48-
// BW-128: %[[CST:.*]] = arith.constant dense<{{.*}}> : vector<4xf32>
49-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
10+
// ALL-LABEL: test_result_bitwidth_64
11+
func.func @test_result_bitwidth_64(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
5012

51-
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
13+
// BW-128: arith.constant {{.*}} vector<4xf32>
14+
// BW-0: arith.constant {{.*}} vector<2x2xf32>
5215
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
5316

54-
// BW-128: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
55-
// BW-0: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32>
56-
%5 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0,3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 5.0, 6.0]]> : vector<4x4xf32>
57-
58-
// Arith and math ops are handled in generic way, check some of them
59-
// BW-128: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32>
60-
// BW-0: %[[SIN:.*]] = math.sin %[[ORIG_ARG]] : vector<2x2xf32>
17+
// BW-128: math.sin {{.*}} vector<4xf32>
18+
// BW-0: math.sin {{.*}} vector<2x2xf32>
6119
%1 = math.sin %arg0 : vector<2x2xf32>
6220

63-
// BW-128: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
64-
// BW-0: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32>
65-
%6 = math.sin %arg1 : vector<4x4xf32>
66-
67-
// BW-128: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32>
68-
// BW-0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
69-
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
70-
71-
// BW-128: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
72-
// BW-0: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32>
73-
%7 = arith.addf %arg1, %5 : vector<4x4xf32>
74-
75-
// ALL: return %[[RES]] : vector<2x2xf32>
7621
return %0 : vector<2x2xf32>
7722
}
7823

7924
// -----
8025

8126
// ALL-LABEL: test_index_no_linearize
8227
func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> {
28+
8329
// BW-128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
30+
// BW-0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex>
8431
%0 = arith.addi %arg0, %arg1 : vector<2x2xindex>
8532
return %0 : vector<2x2xindex>
8633
}
8734

8835
// -----
8936

90-
// ALL-LABEL: func.func @test_scalable_linearize(
91-
// ALL-SAME: %[[ARG_0:.*]]: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
92-
func.func @test_scalable_linearize(%arg0: vector<2x[2]xf32>) -> vector<2x[2]xf32> {
93-
// BW-128: %[[SC:.*]] = vector.shape_cast %[[ARG_0]] : vector<2x[2]xf32> to vector<[4]xf32>
94-
// BW-128: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<[4]xf32>
95-
// BW-0: %[[CST:.*]] = arith.constant dense<3.000000e+00> : vector<2x[2]xf32>
96-
%0 = arith.constant dense<[[3., 3.], [3., 3.]]> : vector<2x[2]xf32>
97-
98-
// BW-128: %[[SIN:.*]] = math.sin %[[SC]] : vector<[4]xf32>
99-
// BW-0: %[[SIN:.*]] = math.sin %[[ARG_0]] : vector<2x[2]xf32>
100-
%1 = math.sin %arg0 : vector<2x[2]xf32>
101-
102-
// BW-128: %[[ADDF:.*]] = arith.addf %[[SIN]], %[[CST]] : vector<[4]xf32>
103-
// BW-0: %[[RES:.*]] = arith.addf %[[CST]], %[[SIN]] : vector<2x[2]xf32>
104-
%2 = arith.addf %0, %1 : vector<2x[2]xf32>
105-
106-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[ADDF]] : vector<[4]xf32> to vector<2x[2]xf32>
107-
// ALL: return %[[RES]] : vector<2x[2]xf32>
108-
return %2 : vector<2x[2]xf32>
109-
}
110-
111-
112-
// -----
113-
114-
// ALL-LABEL: test_extract_strided_slice_1
115-
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
116-
func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
117-
118-
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
119-
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
120-
// BW-128-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
121-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
122-
// BW-128: return %[[RES]] : vector<2x2xf32>
123-
124-
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
125-
// BW-0: return %[[RES]] : vector<2x2xf32>
126-
%0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
127-
: vector<4x8xf32> to vector<2x2xf32>
128-
return %0 : vector<2x2xf32>
129-
}
130-
131-
132-
// -----
133-
134-
// ALL-LABEL: test_extract_strided_slice_2
135-
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
136-
func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
137-
138-
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
139-
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
140-
// BW-128-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
141-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
142-
// BW-128: return %[[RES]] : vector<1x4x2xf32>
143-
144-
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
145-
// BW-0: return %[[RES]] : vector<1x4x2xf32>
146-
%0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
147-
: vector<2x8x2xf32> to vector<1x4x2xf32>
148-
return %0 : vector<1x4x2xf32>
149-
}
150-
151-
// -----
152-
153-
// ALL-LABEL: test_vector_shuffle
154-
// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
155-
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
156-
157-
// BW-128-DAG: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
158-
// BW-128-DAG: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
159-
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
160-
// BW-128-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
161-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
162-
// BW-128: return %[[RES]] : vector<8x2xf32>
163-
164-
// BW-0: %[[RES:.*]] = vector.shuffle %[[ORIG_ARG0]], %[[ORIG_ARG1]] [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
165-
// BW-0: return %[[RES]] : vector<8x2xf32>
166-
%0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
167-
return %0 : vector<8x2xf32>
168-
}
169-
170-
// -----
171-
172-
// ALL-LABEL: test_vector_extract
173-
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
174-
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
175-
176-
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
177-
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
178-
// BW-128-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
179-
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
180-
// BW-128: return %[[RES]] : vector<8x2xf32>
181-
182-
// BW-0: %[[RES:.*]] = vector.extract %[[ORIG_ARG]][1] : vector<8x2xf32> from vector<2x8x2xf32>
183-
// BW-0: return %[[RES]] : vector<8x2xf32>
184-
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
185-
return %0 : vector<8x2xf32>
186-
}
187-
188-
// -----
37+
// The logic for the insert op with regards to the bitwidth threshold is
38+
// different to the other ops, so we test it here. Specifically, the logic
39+
// is based on the bitwidth of the value to store.
18940

19041
// ALL-LABEL: test_vector_insert
19142
// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -194,9 +45,6 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
19445
// BW-128-DAG: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
19546
// BW-128-DAG: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
19647
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
197-
// BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
198-
// BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
199-
// BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
20048
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
20149
// BW-128: return %[[RES]] : vector<2x8x4xf32>
20250

@@ -207,56 +55,3 @@ func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>)
20755
return %0 : vector<2x8x4xf32>
20856
}
20957

210-
// -----
211-
212-
// ALL-LABEL: test_vector_bitcast
213-
// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32>
214-
func.func @test_vector_bitcast(%arg0: vector<4x4xf32>) -> vector<4x8xf16> {
215-
216-
// BW-128: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
217-
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x4xf32> to vector<4x8xf16>
218-
%1 = vector.bitcast %arg0 : vector<4x4xf32> to vector<4x8xf16>
219-
return %1 : vector<4x8xf16>
220-
}
221-
222-
// -----
223-
224-
// ALL-LABEL: test_vector_bitcast
225-
// ALL-SAME: %[[ARG_0:.*]]: vector<4x2xf32>
226-
func.func @test_vector_bitcast(%arg0: vector<4x2xf32>) -> vector<4x4xf16> {
227-
// BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x2xf32> to vector<8xf32>
228-
// BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<8xf32> to vector<16xf16>
229-
// BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<16xf16> to vector<4x4xf16>
230-
231-
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x2xf32> to vector<4x4xf16>
232-
%1 = vector.bitcast %arg0 : vector<4x2xf32> to vector<4x4xf16>
233-
return %1 : vector<4x4xf16>
234-
}
235-
236-
// -----
237-
238-
// ALL-LABEL: test_vector_bitcast
239-
// ALL-SAME: %[[ARG_0:.*]]: vector<4x[2]xf32>
240-
func.func @test_vector_bitcast(%arg0: vector<4x[2]xf32>) -> vector<4x[4]xf16> {
241-
// BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<4x[2]xf32> to vector<[8]xf32>
242-
// BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
243-
// BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<4x[4]xf16>
244-
245-
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4x[2]xf32> to vector<4x[4]xf16>
246-
%1 = vector.bitcast %arg0 : vector<4x[2]xf32> to vector<4x[4]xf16>
247-
return %1 : vector<4x[4]xf16>
248-
}
249-
250-
// -----
251-
252-
// ALL-LABEL: test_vector_bitcast
253-
// ALL-SAME: %[[ARG_0:.*]]: vector<[4]x2xf32>
254-
func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> {
255-
// BW-128: %[[DOWNCAST:.*]] = vector.shape_cast %[[ARG_0]] : vector<[4]x2xf32> to vector<[8]xf32>
256-
// BW-128: %[[BITCAST:.*]] = vector.bitcast %[[DOWNCAST]] : vector<[8]xf32> to vector<[16]xf16>
257-
// BW-128: %[[UPCAST:.*]] = vector.shape_cast %[[BITCAST]] : vector<[16]xf16> to vector<[4]x4xf16>
258-
259-
// BW-0: %[[BITCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<[4]x2xf32> to vector<[4]x4xf16>
260-
%1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16>
261-
return %1 : vector<[4]x4xf16>
262-
}

0 commit comments

Comments
 (0)