Skip to content

Commit f2c557d

Browse files
authored
[SYCL] Add e2e test for llvm.scmp/ucmp.* (#15611)
Requires KhronosGroup/SPIRV-LLVM-Translator#2741 to be present in intel/llvm before merging. --------- Signed-off-by: Marcos Maronas <[email protected]>
1 parent 8a5d80f commit f2c557d

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// RUN: %{build} -Wno-error=psabi -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <sycl/detail/core.hpp>
5+
6+
// Define vector types for different integer bit widths. We need these to
7+
// trigger llvm.scmp/ucmp for vector types. std::array or sycl::vec don't
8+
// trigger these, as they are not lowered to vector types.
9+
typedef int8_t v4i8_t __attribute__((ext_vector_type(4)));
10+
typedef int16_t v4i16_t __attribute__((ext_vector_type(4)));
11+
typedef int32_t v4i32_t __attribute__((ext_vector_type(4)));
12+
typedef int64_t v4i64_t __attribute__((ext_vector_type(4)));
13+
typedef uint8_t v4u8_t __attribute__((ext_vector_type(4)));
14+
typedef uint16_t v4u16_t __attribute__((ext_vector_type(4)));
15+
typedef uint32_t v4u32_t __attribute__((ext_vector_type(4)));
16+
typedef uint64_t v4u64_t __attribute__((ext_vector_type(4)));
17+
18+
// Check if a given type is a vector type or not. Used in submitAndCheck to
19+
// branch the check: we need element-wise comparison for vector types. Default
20+
// case: T is not a vector type.
21+
template <typename T> struct is_vector : std::false_type {};
22+
// Specialization for vector types. If T has
23+
// __attribute__((ext_vector_type(N))), then it's a vector type.
24+
template <typename T, std::size_t N>
25+
struct is_vector<T __attribute__((ext_vector_type(N)))> : std::true_type {};
26+
template <typename T> inline constexpr bool is_vector_v = is_vector<T>::value;
27+
28+
// Get the length of a vector type. Used in submitAndCheck to iterate over the
29+
// elements of the vector type. Default case: length is 1.
30+
template <typename T> struct vector_length {
31+
static constexpr std::size_t value = 1;
32+
};
33+
// Specialization for vector types. If T has
34+
// __attribute__((ext_vector_type(N))), then the length is N.
35+
template <typename T, std::size_t N>
36+
struct vector_length<T __attribute__((ext_vector_type(N)))> {
37+
static constexpr std::size_t value = N;
38+
};
39+
template <typename T>
40+
inline constexpr std::size_t vector_length_v = vector_length<T>::value;
41+
42+
// Get the element type of a vector type. Used in submitVecCombinations to
43+
// convert unsigned vector types to signed vector types for return type. Primary
44+
// template for element_type.
45+
template <typename T> struct element_type;
46+
// Specialization for vector types. If T has
47+
// __attribute__((ext_vector_type(N))), return T.
48+
template <typename T, int N>
49+
struct element_type<T __attribute__((ext_vector_type(N)))> {
50+
using type = T;
51+
};
52+
// Helper alias template.
53+
template <typename T> using element_type_t = typename element_type<T>::type;
54+
55+
// TypeList for packing the types that we want to test.
56+
// Base case for variadic template recursion.
57+
template <typename...> struct TypeList {};
58+
59+
// Function to trigger llvm.scmp/ucmp.
60+
template <typename RetTy, typename ArgTy>
61+
void compare(RetTy &res, ArgTy x, ArgTy y) {
62+
auto lessOrEq = (x <= y);
63+
auto lessThan = (x < y);
64+
res = lessOrEq ? (lessThan ? RetTy(-1) : RetTy(0)) : RetTy(1);
65+
}
66+
67+
// Function to submit kernel and check device result with host result.
68+
template <typename RetTy, typename ArgTy>
69+
void submitAndCheck(sycl::queue &q, ArgTy x, ArgTy y) {
70+
RetTy res;
71+
{
72+
sycl::buffer<RetTy, 1> res_b{&res, 1};
73+
q.submit([&](sycl::handler &cgh) {
74+
sycl::accessor acc{res_b, cgh, sycl::write_only};
75+
cgh.single_task<>([=] {
76+
RetTy tmp;
77+
compare<RetTy, ArgTy>(tmp, x, y);
78+
acc[0] = tmp;
79+
});
80+
});
81+
}
82+
RetTy expectedRes;
83+
compare<RetTy, ArgTy>(expectedRes, x, y);
84+
if constexpr (is_vector_v<RetTy>) {
85+
for (int i = 0; i < vector_length_v<RetTy>; ++i) {
86+
assert(res[i] == expectedRes[i]);
87+
}
88+
} else {
89+
assert(res == expectedRes);
90+
}
91+
}
92+
93+
// Helper to call submitAndCheck for each combination.
94+
template <typename RetTypes, typename ArgTypes>
95+
void submitAndCheckCombination(sycl::queue &q, int x, int y) {
96+
submitAndCheck<RetTypes, ArgTypes>(q, x, y);
97+
}
98+
99+
// Function to generate all the combinations possible with the two type lists.
100+
// It implements the following pseudocode :
101+
// foreach RetTy : RetTypes
102+
// foreach ArgTy : ArgTypes
103+
// submitAndCheck<RetTy, ArgTy>(q, x, y);
104+
105+
// Recursive case to generate combinations.
106+
template <typename RetType, typename... RetTypes, typename... ArgTypes>
107+
void submitCombinations(sycl::queue &q, int x, int y,
108+
TypeList<RetType, RetTypes...>, TypeList<ArgTypes...>) {
109+
(submitAndCheckCombination<RetType, ArgTypes>(q, x, y), ...);
110+
submitCombinations(q, x, y, TypeList<RetTypes...>{}, TypeList<ArgTypes...>{});
111+
}
112+
// Base case to stop recursion.
113+
template <typename... ArgTypes>
114+
void submitCombinations(sycl::queue &, int, int, TypeList<>,
115+
TypeList<ArgTypes...>) {}
116+
117+
// Function to generate all the combinations out of the given list.
118+
// It implements the following pseudocode :
119+
// foreach ArgTy : ArgTypes
120+
// submitAndCheck<ArgTy, ArgTy>(q, x, y);
121+
122+
// Recursive case to generate combinations.
123+
template <typename ArgType, typename... ArgTypes>
124+
void submitVecCombinations(sycl::queue &q, int x, int y,
125+
TypeList<ArgType, ArgTypes...>) {
126+
// Use signed types for return type, as it may return -1.
127+
using ElemType = std::make_signed_t<element_type_t<ArgType>>;
128+
using RetType =
129+
ElemType __attribute__((ext_vector_type(vector_length_v<ArgType>)));
130+
submitAndCheckCombination<RetType, ArgType>(q, x, y);
131+
submitVecCombinations(q, x, y, TypeList<ArgTypes...>{});
132+
}
133+
// Base case to stop recursion.
134+
void submitVecCombinations(sycl::queue &, int, int, TypeList<>) {}
135+
136+
int main(int argc, char **argv) {
137+
sycl::queue q;
138+
// RetTypes includes only signed types because it may return -1.
139+
using RetTypes = TypeList<int8_t, int16_t, int32_t, int64_t>;
140+
using ArgTypes = TypeList<int8_t, int16_t, int32_t, int64_t, uint8_t,
141+
uint16_t, uint32_t, uint64_t>;
142+
submitCombinations(q, 50, 49, RetTypes{}, ArgTypes{});
143+
submitCombinations(q, 50, 50, RetTypes{}, ArgTypes{});
144+
submitCombinations(q, 50, 51, RetTypes{}, ArgTypes{});
145+
using VecTypes = TypeList<v4i8_t, v4i16_t, v4i32_t, v4i64_t, v4u8_t, v4u16_t,
146+
v4u32_t, v4u64_t>;
147+
submitVecCombinations(q, 50, 49, VecTypes{});
148+
submitVecCombinations(q, 50, 50, VecTypes{});
149+
submitVecCombinations(q, 50, 51, VecTypes{});
150+
return 0;
151+
}

0 commit comments

Comments
 (0)