Skip to content

Commit d694427

Browse files
authored
[SingleSource/Vectorizer] Add unit tests for FindLastIV pattern. (#193)
This patch adds runtime test case for vectorization of FindLastIV reduction idiom: int32_t Rdx = -1; for (int32_t I = 0; I < TC; I++) { Rdx = A[I] > B[I] ? I : Rdx; } return Rdx; Improving test coverage for llvm/llvm-project#67812 and llvm/llvm-project#120395.
1 parent 72993e7 commit d694427

File tree

3 files changed

+416
-0
lines changed

3 files changed

+416
-0
lines changed

SingleSource/UnitTests/Vectorizer/common.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@
99
Init _Pragma("clang loop vectorize(enable)") Loop \
1010
};
1111

12+
#define DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(Init, Loop, Type) \
13+
auto ScalarFn = [](auto *A, auto *B, Type TC) -> Type { \
14+
Init _Pragma("clang loop vectorize(disable) interleave_count(1)") Loop \
15+
}; \
16+
auto VectorFn = [](auto *A, auto *B, Type TC) -> Type { \
17+
Init _Pragma("clang loop vectorize(enable)") Loop \
18+
};
19+
1220
#define DEFINE_SCALAR_AND_VECTOR_FN3(Loop) \
1321
auto ScalarFn = [](auto *A, auto *B, auto *C, unsigned TC) { \
1422
_Pragma("clang loop vectorize(disable) interleave_count(1)") Loop \
Lines changed: 380 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,380 @@
1+
#include <algorithm>
2+
#include <functional>
3+
#include <iostream>
4+
#include <limits>
5+
#include <memory>
6+
#include <stdint.h>
7+
8+
#include "common.h"
9+
10+
template <typename RetTy, typename Ty>
11+
using Fn2Ty = std::function<RetTy(Ty *, Ty *, RetTy)>;
12+
template <typename RetTy, typename Ty>
13+
static void checkVectorFunction(Fn2Ty<RetTy, Ty> ScalarFn,
14+
Fn2Ty<RetTy, Ty> VectorFn, const char *Name) {
15+
std::cout << "Checking " << Name << "\n";
16+
17+
unsigned N = 1000;
18+
std::unique_ptr<Ty[]> Src1(new Ty[N]);
19+
std::unique_ptr<Ty[]> Src2(new Ty[N]);
20+
init_data(Src1, N);
21+
init_data(Src2, N);
22+
23+
// Test VectorFn with different input data.
24+
{
25+
// Check with random inputs.
26+
auto Reference = ScalarFn(&Src1[0], &Src2[0], N);
27+
auto ToCheck = VectorFn(&Src1[0], &Src2[0], N);
28+
if (Reference != ToCheck) {
29+
std::cerr << "Miscompare\n";
30+
exit(1);
31+
}
32+
}
33+
34+
{
35+
// Check with Src1 > Src2 for all elements.
36+
for (unsigned I = 0; I != N; ++I) {
37+
Src1[I] = std::numeric_limits<Ty>::max();
38+
Src2[I] = std::numeric_limits<Ty>::min();
39+
}
40+
auto Reference = ScalarFn(&Src1[0], &Src2[0], N);
41+
auto ToCheck = VectorFn(&Src1[0], &Src2[0], N);
42+
if (Reference != ToCheck) {
43+
std::cerr << "Miscompare\n";
44+
exit(1);
45+
}
46+
}
47+
48+
{
49+
// Check with Src1 < Src2 for all elements.
50+
for (unsigned I = 0; I != N; ++I) {
51+
Src1[I] = std::numeric_limits<Ty>::min();
52+
Src2[I] = std::numeric_limits<Ty>::max();
53+
}
54+
auto Reference = ScalarFn(&Src1[0], &Src2[0], N);
55+
auto ToCheck = VectorFn(&Src1[0], &Src2[0], N);
56+
if (Reference != ToCheck) {
57+
std::cerr << "Miscompare\n";
58+
exit(1);
59+
}
60+
}
61+
62+
{
63+
// Check with only Src1[998] > Src2[998].
64+
for (unsigned I = 0; I != N; ++I)
65+
Src1[I] = Src2[I] = std::numeric_limits<Ty>::min();
66+
Src1[998] = std::numeric_limits<Ty>::max();
67+
auto Reference = ScalarFn(&Src1[0], &Src2[0], N);
68+
auto ToCheck = VectorFn(&Src1[0], &Src2[0], N);
69+
if (Reference != ToCheck) {
70+
std::cerr << "Miscompare\n";
71+
exit(1);
72+
}
73+
}
74+
75+
{
76+
// Check with only Src1[0] > Src2[0].
77+
for (unsigned I = 0; I != N; ++I)
78+
Src1[I] = Src2[I] = std::numeric_limits<Ty>::min();
79+
Src1[0] = std::numeric_limits<Ty>::max();
80+
auto Reference = ScalarFn(&Src1[0], &Src2[0], N);
81+
auto ToCheck = VectorFn(&Src1[0], &Src2[0], N);
82+
if (Reference != ToCheck) {
83+
std::cerr << "Miscompare\n";
84+
exit(1);
85+
}
86+
}
87+
88+
{
89+
// Check with only Src1[N - 1] > Src2[N - 1].
90+
for (unsigned I = 0; I != N; ++I)
91+
Src1[I] = Src2[I] = std::numeric_limits<Ty>::min();
92+
Src1[N - 1] = std::numeric_limits<Ty>::max();
93+
auto Reference = ScalarFn(&Src1[0], &Src2[0], N);
94+
auto ToCheck = VectorFn(&Src1[0], &Src2[0], N);
95+
if (Reference != ToCheck) {
96+
std::cerr << "Miscompare\n";
97+
exit(1);
98+
}
99+
}
100+
101+
{
102+
// Check with only Src1[0] > Src2[0] and Src1[N - 1] > Src2[N - 1].
103+
for (unsigned I = 0; I != N; ++I)
104+
Src1[I] = Src2[I] = std::numeric_limits<Ty>::min();
105+
Src1[0] = Src1[N - 1] = std::numeric_limits<Ty>::max();
106+
auto Reference = ScalarFn(&Src1[0], &Src2[0], N);
107+
auto ToCheck = VectorFn(&Src1[0], &Src2[0], N);
108+
if (Reference != ToCheck) {
109+
std::cerr << "Miscompare\n";
110+
exit(1);
111+
}
112+
}
113+
}
114+
115+
int main(void) {
116+
rng = std::mt19937(15);
117+
118+
#define INC_COND(Start, Step, RetTy) for (RetTy I = Start; I < TC; I += Step)
119+
#define DEC_COND(End, Step, RetTy) for (RetTy I = TC; I > End; I -= Step)
120+
121+
#define DEFINE_FINDLAST_LOOP_BODY(TrueVal, FalseVal, ForCond) \
122+
ForCond { Rdx = A[I] > B[I] ? TrueVal : FalseVal; } \
123+
return Rdx;
124+
125+
{
126+
// Find the last index where A[I] > B[I] and update 32-bits Rdx when the
127+
// condition is true.
128+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
129+
int32_t Rdx = -1;,
130+
DEFINE_FINDLAST_LOOP_BODY(
131+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
132+
/* ForCond= */
133+
INC_COND(/* Start= */ 0, /* Step= */ 1, /* RetTy= */ int32_t)),
134+
int32_t);
135+
checkVectorFunction<int32_t, int32_t>(ScalarFn, VectorFn,
136+
"findlast_icmp_s32_true_update");
137+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
138+
"findlast_fcmp_s32_true_update");
139+
}
140+
141+
{
142+
// Find the last index where A[I] > B[I] and update 16-bits Rdx when the
143+
// condition is true.
144+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
145+
int16_t Rdx = -1;,
146+
DEFINE_FINDLAST_LOOP_BODY(
147+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
148+
/* ForCond= */
149+
INC_COND(/* Start= */ 0, /* Step= */ 1, /* RetTy= */ int16_t)),
150+
int16_t);
151+
checkVectorFunction<int16_t, int16_t>(ScalarFn, VectorFn,
152+
"findlast_icmp_s16_true_update");
153+
}
154+
155+
{
156+
// Update 32-bits Rdx when the condition A[I] > B[I] is false.
157+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
158+
int32_t Rdx = -1;,
159+
DEFINE_FINDLAST_LOOP_BODY(
160+
/* TrueVal= */ Rdx, /* FalseVal= */ I,
161+
/* ForCond= */
162+
INC_COND(/* Start= */ 0, /* Step= */ 1, /* RetTy= */ int32_t)),
163+
int32_t);
164+
checkVectorFunction<int32_t, int32_t>(ScalarFn, VectorFn,
165+
"findlast_icmp_s32_false_update");
166+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
167+
"findlast_fcmp_s32_false_update");
168+
}
169+
170+
{
171+
// Update 16-bits Rdx when the condition A[I] > B[I] is false.
172+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
173+
int16_t Rdx = -1;,
174+
DEFINE_FINDLAST_LOOP_BODY(
175+
/* TrueVal= */ Rdx, /* FalseVal= */ I,
176+
/* ForCond= */
177+
INC_COND(/* Start= */ 0, /* Step= */ 1, /* RetTy= */ int16_t)),
178+
int16_t);
179+
checkVectorFunction<int16_t, int16_t>(ScalarFn, VectorFn,
180+
"findlast_icmp_s16_false_update");
181+
}
182+
183+
{
184+
// Find the last 32-bits index with the start value TC.
185+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
186+
int32_t Rdx = TC;,
187+
DEFINE_FINDLAST_LOOP_BODY(
188+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
189+
/* ForCond= */
190+
INC_COND(/* Start= */ 0, /* Step= */ 1, /* RetTy= */ int32_t)),
191+
int32_t);
192+
checkVectorFunction<int32_t, int32_t>(ScalarFn, VectorFn,
193+
"findlast_icmp_s32_start_TC");
194+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
195+
"findlast_fcmp_s32_start_TC");
196+
}
197+
198+
{
199+
// Find the last 16-bits index with the start value TC.
200+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
201+
int16_t Rdx = TC;,
202+
DEFINE_FINDLAST_LOOP_BODY(
203+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
204+
/* ForCond= */
205+
INC_COND(/* Start= */ 0, /* Step= */ 1, /* RetTy= */ int16_t)),
206+
int16_t);
207+
checkVectorFunction<int16_t, int16_t>(ScalarFn, VectorFn,
208+
"findlast_icmp_s16_start_TC");
209+
}
210+
211+
{
212+
// Increment the 32-bits induction variable by 2.
213+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
214+
int32_t Rdx = -1;,
215+
DEFINE_FINDLAST_LOOP_BODY(
216+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
217+
/* ForCond= */
218+
INC_COND(/* Start= */ 0, /* Step= */ 2, /* RetTy= */ int32_t)),
219+
int32_t);
220+
checkVectorFunction<int32_t, int32_t>(ScalarFn, VectorFn,
221+
"findlast_icmp_s32_inc_2");
222+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
223+
"findlast_fcmp_s32_inc_2");
224+
}
225+
226+
{
227+
// Increment the 16-bits induction variable by 2.
228+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
229+
int16_t Rdx = -1;,
230+
DEFINE_FINDLAST_LOOP_BODY(
231+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
232+
/* ForCond= */
233+
INC_COND(/* Start= */ 0, /* Step= */ 2, /* RetTy= */ int16_t)),
234+
int16_t);
235+
checkVectorFunction<int16_t, int16_t>(ScalarFn, VectorFn,
236+
"findlast_icmp_s16_inc_2");
237+
}
238+
239+
{
240+
// Check with decreasing 32-bits induction variable.
241+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
242+
int32_t Rdx = -1;,
243+
DEFINE_FINDLAST_LOOP_BODY(
244+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
245+
/* ForCond= */
246+
DEC_COND(/* End= */ 0, /* Step= */ 1, /* RetTy= */ int32_t)),
247+
int32_t);
248+
checkVectorFunction<int32_t, int32_t>(
249+
ScalarFn, VectorFn, "findlast_icmp_s32_start_decreasing_induction");
250+
checkVectorFunction<int32_t, float>(
251+
ScalarFn, VectorFn, "findlast_fcmp_s32_start_decreasing_induction");
252+
}
253+
254+
{
255+
// Check with decreasing 16-bits induction variable.
256+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
257+
int16_t Rdx = -1;,
258+
DEFINE_FINDLAST_LOOP_BODY(
259+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
260+
/* ForCond= */
261+
DEC_COND(/* End= */ 0, /* Step= */ 1, /* RetTy= */ int16_t)),
262+
int16_t);
263+
checkVectorFunction<int16_t, int16_t>(
264+
ScalarFn, VectorFn, "findlast_icmp_s16_start_decreasing_induction");
265+
}
266+
267+
{
268+
// Check with 32-bits the induction variable starts from 3.
269+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
270+
int32_t Rdx = -1;,
271+
DEFINE_FINDLAST_LOOP_BODY(
272+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
273+
/* ForCond= */
274+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int32_t)),
275+
int32_t);
276+
checkVectorFunction<int32_t, int32_t>(ScalarFn, VectorFn,
277+
"findlast_icmp_s32_iv_start_3");
278+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
279+
"findlast_fcmp_s32_iv_start_3");
280+
}
281+
282+
{
283+
// Check with 16-bits the induction variable starts from 3.
284+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
285+
int16_t Rdx = -1;,
286+
DEFINE_FINDLAST_LOOP_BODY(
287+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
288+
/* ForCond= */
289+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int16_t)),
290+
int16_t);
291+
checkVectorFunction<int16_t, int16_t>(ScalarFn, VectorFn,
292+
"findlast_icmp_s16_iv_start_3");
293+
}
294+
295+
{
296+
// Check with start value of 3 and 32-bits induction variable starts at 3.
297+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
298+
int32_t Rdx = 3;,
299+
DEFINE_FINDLAST_LOOP_BODY(
300+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
301+
/* ForCond= */
302+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int32_t)),
303+
int32_t);
304+
checkVectorFunction<int32_t, int32_t>(
305+
ScalarFn, VectorFn, "findlast_icmp_s32_start_3_iv_start_3");
306+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
307+
"findlast_fcmp_s32_start_3_iv_start_3");
308+
}
309+
310+
{
311+
// Check with start value of 3 and 16-bits induction variable starts at 3.
312+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
313+
int16_t Rdx = 3;,
314+
DEFINE_FINDLAST_LOOP_BODY(
315+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
316+
/* ForCond= */
317+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int16_t)),
318+
int16_t);
319+
checkVectorFunction<int16_t, int16_t>(
320+
ScalarFn, VectorFn, "findlast_icmp_s16_start_3_iv_start_3");
321+
}
322+
323+
{
324+
// Check with start value of 2 and 32-bits induction variable starts at 3.
325+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
326+
int32_t Rdx = 2;,
327+
DEFINE_FINDLAST_LOOP_BODY(
328+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
329+
/* ForCond= */
330+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int32_t)),
331+
int32_t);
332+
checkVectorFunction<int32_t, int32_t>(
333+
ScalarFn, VectorFn, "findlast_icmp_s32_start_2_iv_start_3");
334+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
335+
"findlast_fcmp_s32_start_2_iv_start_3");
336+
}
337+
338+
{
339+
// Check with start value of 2 and 16-bits induction variable starts at 3.
340+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
341+
int16_t Rdx = 2;,
342+
DEFINE_FINDLAST_LOOP_BODY(
343+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
344+
/* ForCond= */
345+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int16_t)),
346+
int16_t);
347+
checkVectorFunction<int16_t, int16_t>(
348+
ScalarFn, VectorFn, "findlast_icmp_s16_start_2_iv_start_3");
349+
}
350+
351+
{
352+
// Check with start value of 4 and 32-bits induction variable starts at 3.
353+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
354+
int32_t Rdx = 4;,
355+
DEFINE_FINDLAST_LOOP_BODY(
356+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
357+
/* ForCond= */
358+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int32_t)),
359+
int32_t);
360+
checkVectorFunction<int32_t, int32_t>(
361+
ScalarFn, VectorFn, "findlast_icmp_s32_start_4_iv_start_3");
362+
checkVectorFunction<int32_t, float>(ScalarFn, VectorFn,
363+
"findlast_fcmp_s32_start_4_iv_start_3");
364+
}
365+
366+
{
367+
// Check with start value of 4 and 16-bits induction variable starts at 3.
368+
DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(
369+
int16_t Rdx = 4;,
370+
DEFINE_FINDLAST_LOOP_BODY(
371+
/* TrueVal= */ I, /* FalseVal= */ Rdx,
372+
/* ForCond= */
373+
INC_COND(/* Start= */ 3, /* Step= */ 1, /* RetTy= */ int16_t)),
374+
int16_t);
375+
checkVectorFunction<int16_t, int16_t>(
376+
ScalarFn, VectorFn, "findlast_icmp_s16_start_4_iv_start_3");
377+
}
378+
379+
return 0;
380+
}

0 commit comments

Comments
 (0)