Skip to content

Commit eb3ac32

Browse files
alexeyvoronov-intelromanovvlad
authored andcommitted
[SYCL] Add vector_size check for built-in's functions.
To prevent calls of built-in functions with vectors of mismatched length. Signed-off-by: Alexey Voronov <[email protected]>
1 parent 8c2e09a commit eb3ac32

File tree

5 files changed

+124
-20
lines changed

5 files changed

+124
-20
lines changed

sycl/include/CL/sycl/builtins.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ template <typename T, typename T2>
232232
detail::enable_if_t<
233233
detail::is_genfloat<T>::value && detail::is_genfloatptr<T2>::value, T>
234234
fract(T x, T2 iptr) __NOEXC {
235+
detail::check_vector_size<T, T2>();
235236
return __sycl_std::__invoke_fract<T>(x, iptr);
236237
}
237238

@@ -240,6 +241,7 @@ template <typename T, typename T2>
240241
detail::enable_if_t<
241242
detail::is_genfloat<T>::value && detail::is_genintptr<T2>::value, T>
242243
frexp(T x, T2 exp) __NOEXC {
244+
detail::check_vector_size<T, T2>();
243245
return __sycl_std::__invoke_frexp<T>(x, exp);
244246
}
245247

@@ -277,6 +279,7 @@ template <typename T, typename T2>
277279
detail::enable_if_t<
278280
detail::is_vgenfloat<T>::value && detail::is_intn<T2>::value, T>
279281
ldexp(T x, T2 k) __NOEXC {
282+
detail::check_vector_size<T, T2>();
280283
return __sycl_std::__invoke_ldexp<T>(x, k);
281284
}
282285

@@ -291,6 +294,7 @@ template <typename T, typename T2>
291294
detail::enable_if_t<
292295
detail::is_genfloat<T>::value && detail::is_genintptr<T2>::value, T>
293296
lgamma_r(T x, T2 signp) __NOEXC {
297+
detail::check_vector_size<T, T2>();
294298
return __sycl_std::__invoke_lgamma_r<T>(x, signp);
295299
}
296300

@@ -348,6 +352,7 @@ template <typename T, typename T2>
348352
detail::enable_if_t<
349353
detail::is_genfloat<T>::value && detail::is_genfloatptr<T2>::value, T>
350354
modf(T x, T2 iptr) __NOEXC {
355+
detail::check_vector_size<T, T2>();
351356
return __sycl_std::__invoke_modf<T>(x, iptr);
352357
}
353358

@@ -376,6 +381,7 @@ template <typename T, typename T2>
376381
detail::enable_if_t<
377382
detail::is_genfloat<T>::value && detail::is_genint<T2>::value, T>
378383
pown(T x, T2 y) __NOEXC {
384+
detail::check_vector_size<T, T2>();
379385
return __sycl_std::__invoke_pown<T>(x, y);
380386
}
381387

@@ -397,6 +403,7 @@ template <typename T, typename T2>
397403
detail::enable_if_t<
398404
detail::is_genfloat<T>::value && detail::is_genintptr<T2>::value, T>
399405
remquo(T x, T y, T2 quo) __NOEXC {
406+
detail::check_vector_size<T, T2>();
400407
return __sycl_std::__invoke_remquo<T>(x, y, quo);
401408
}
402409

@@ -411,6 +418,7 @@ template <typename T, typename T2>
411418
detail::enable_if_t<
412419
detail::is_genfloat<T>::value && detail::is_genint<T2>::value, T>
413420
rootn(T x, T2 y) __NOEXC {
421+
detail::check_vector_size<T, T2>();
414422
return __sycl_std::__invoke_rootn<T>(x, y);
415423
}
416424

@@ -437,6 +445,7 @@ template <typename T, typename T2>
437445
detail::enable_if_t<
438446
detail::is_genfloat<T>::value && detail::is_genfloatptr<T2>::value, T>
439447
sincos(T x, T2 cosval) __NOEXC {
448+
detail::check_vector_size<T, T2>();
440449
return __sycl_std::__invoke_sincos<T>(x, cosval);
441450
}
442451

@@ -860,6 +869,7 @@ detail::enable_if_t<detail::is_igeninteger8bit<T>::value &&
860869
detail::is_ugeninteger8bit<T2>::value,
861870
detail::make_larger_t<T>>
862871
upsample(T hi, T2 lo) __NOEXC {
872+
detail::check_vector_size<T, T2>();
863873
return __sycl_std::__invoke_s_upsample<detail::make_larger_t<T>>(hi, lo);
864874
}
865875

@@ -877,6 +887,7 @@ detail::enable_if_t<detail::is_igeninteger16bit<T>::value &&
877887
detail::is_ugeninteger16bit<T2>::value,
878888
detail::make_larger_t<T>>
879889
upsample(T hi, T2 lo) __NOEXC {
890+
detail::check_vector_size<T, T2>();
880891
return __sycl_std::__invoke_s_upsample<detail::make_larger_t<T>>(hi, lo);
881892
}
882893

@@ -894,6 +905,7 @@ detail::enable_if_t<detail::is_igeninteger32bit<T>::value &&
894905
detail::is_ugeninteger32bit<T2>::value,
895906
detail::make_larger_t<T>>
896907
upsample(T hi, T2 lo) __NOEXC {
908+
detail::check_vector_size<T, T2>();
897909
return __sycl_std::__invoke_s_upsample<detail::make_larger_t<T>>(hi, lo);
898910
}
899911

@@ -1290,6 +1302,7 @@ template <typename T, typename T2>
12901302
detail::enable_if_t<
12911303
detail::is_geninteger<T>::value && detail::is_igeninteger<T2>::value, T>
12921304
select(T a, T b, T2 c) __NOEXC {
1305+
detail::check_vector_size<T, T2>();
12931306
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
12941307
}
12951308

@@ -1298,6 +1311,7 @@ template <typename T, typename T2>
12981311
detail::enable_if_t<
12991312
detail::is_geninteger<T>::value && detail::is_ugeninteger<T2>::value, T>
13001313
select(T a, T b, T2 c) __NOEXC {
1314+
detail::check_vector_size<T, T2>();
13011315
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
13021316
}
13031317

@@ -1306,6 +1320,7 @@ template <typename T, typename T2>
13061320
detail::enable_if_t<
13071321
detail::is_genfloatf<T>::value && detail::is_genint<T2>::value, T>
13081322
select(T a, T b, T2 c) __NOEXC {
1323+
detail::check_vector_size<T, T2>();
13091324
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
13101325
}
13111326

@@ -1314,6 +1329,7 @@ template <typename T, typename T2>
13141329
detail::enable_if_t<
13151330
detail::is_genfloatf<T>::value && detail::is_ugenint<T2>::value, T>
13161331
select(T a, T b, T2 c) __NOEXC {
1332+
detail::check_vector_size<T, T2>();
13171333
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
13181334
}
13191335

@@ -1322,6 +1338,7 @@ template <typename T, typename T2>
13221338
detail::enable_if_t<
13231339
detail::is_genfloatd<T>::value && detail::is_igeninteger64bit<T2>::value, T>
13241340
select(T a, T b, T2 c) __NOEXC {
1341+
detail::check_vector_size<T, T2>();
13251342
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
13261343
}
13271344

@@ -1330,6 +1347,7 @@ template <typename T, typename T2>
13301347
detail::enable_if_t<
13311348
detail::is_genfloatd<T>::value && detail::is_ugeninteger64bit<T2>::value, T>
13321349
select(T a, T b, T2 c) __NOEXC {
1350+
detail::check_vector_size<T, T2>();
13331351
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
13341352
}
13351353

@@ -1338,6 +1356,7 @@ template <typename T, typename T2>
13381356
detail::enable_if_t<
13391357
detail::is_genfloath<T>::value && detail::is_igeninteger16bit<T2>::value, T>
13401358
select(T a, T b, T2 c) __NOEXC {
1359+
detail::check_vector_size<T, T2>();
13411360
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
13421361
}
13431362

@@ -1346,6 +1365,7 @@ template <typename T, typename T2>
13461365
detail::enable_if_t<
13471366
detail::is_genfloath<T>::value && detail::is_ugeninteger16bit<T2>::value, T>
13481367
select(T a, T b, T2 c) __NOEXC {
1368+
detail::check_vector_size<T, T2>();
13491369
return __sycl_std::__invoke_Select<T>(detail::select_arg_c_t<T2>(c), b, a);
13501370
}
13511371

sycl/include/CL/sycl/detail/generic_type_traits.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,39 @@ template <typename T> static constexpr T quiet_NaN() {
569569
return std::numeric_limits<T>::quiet_NaN();
570570
}
571571

572+
// is_same_vector_size
573+
template <int FirstSize, typename... Args> struct is_same_vector_size_impl;
574+
575+
template <int FirstSize, typename T, typename... Args>
576+
class is_same_vector_size_impl<FirstSize, T, Args...> {
577+
using CurrentT = detail::remove_pointer_t<T>;
578+
static constexpr int Size = vector_size<CurrentT>::value;
579+
static constexpr bool IsSizeEqual = (Size == FirstSize);
580+
581+
public:
582+
static constexpr bool value =
583+
IsSizeEqual ? is_same_vector_size_impl<FirstSize, Args...>::value
584+
: false;
585+
};
586+
587+
template <int FirstSize>
588+
struct is_same_vector_size_impl<FirstSize> : std::true_type {};
589+
590+
template <typename T, typename... Args> class is_same_vector_size {
591+
using CurrentT = remove_pointer_t<T>;
592+
static constexpr int Size = vector_size<CurrentT>::value;
593+
594+
public:
595+
static constexpr bool value = is_same_vector_size_impl<Size, Args...>::value;
596+
};
597+
598+
// check_vector_size
599+
template <typename... Args> inline void check_vector_size() {
600+
static_assert(is_same_vector_size<Args...>::value,
601+
"The built-in function arguments must [point to|have] types "
602+
"with the same number of elements.");
603+
}
604+
572605
} // namespace detail
573606
} // namespace sycl
574607
} // namespace cl

sycl/include/CL/sycl/detail/stl_type_traits.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ using remove_const_t = typename std::remove_const<T>::type;
3535

3636
template <typename T> using remove_cv_t = typename std::remove_cv<T>::type;
3737

38+
template <typename T>
39+
using remove_reference_t = typename std::remove_reference<T>::type;
40+
3841
template <typename T> using add_pointer_t = typename std::add_pointer<T>::type;
3942

4043
template <typename T>

sycl/include/CL/sycl/detail/type_traits.hpp

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,33 @@ namespace sycl {
2323
template <typename ElementType, access::address_space Space> class multi_ptr;
2424

2525
namespace detail {
26+
template <typename T, typename R> struct copy_cv_qualifiers;
2627

27-
// Contains a type that is the base type for a scalar or vector type
28-
template <typename T> struct get_base_type { using type = T; };
28+
template <typename T, typename R>
29+
using copy_cv_qualifiers_t = typename copy_cv_qualifiers<T, R>::type;
30+
31+
template <int V> using int_constant = std::integral_constant<int, V>;
2932

30-
template <typename T, int N> struct get_base_type<vec<T, N>> {
33+
// vector_size
34+
// scalars are interpreted as a vector of 1 length.
35+
template <typename T> struct vector_size_impl : int_constant<1> {};
36+
template <typename T, int N>
37+
struct vector_size_impl<vec<T, N>> : int_constant<N> {};
38+
template <typename T>
39+
struct vector_size : vector_size_impl<remove_cv_t<remove_reference_t<T>>> {};
40+
41+
// vector_element
42+
template <typename T> struct vector_element_impl;
43+
template <typename T>
44+
using vector_element_impl_t = typename vector_element_impl<T>::type;
45+
template <typename T> struct vector_element_impl { using type = T; };
46+
template <typename T, int N> struct vector_element_impl<vec<T, N>> {
3147
using type = T;
3248
};
33-
34-
template <typename T> using get_base_type_t = typename get_base_type<T>::type;
49+
template <typename T> struct vector_element {
50+
using type = copy_cv_qualifiers_t<T, vector_element_impl_t<remove_cv_t<T>>>;
51+
};
52+
template <class T> using vector_element_t = typename vector_element<T>::type;
3553

3654
// change_base_type_t
3755
template <typename T, typename B> struct change_base_type { using type = B; };
@@ -66,11 +84,7 @@ template <typename T, typename R> struct copy_cv_qualifiers {
6684
using type = typename copy_cv_qualifiers_impl<T, remove_cv_t<R>>::type;
6785
};
6886

69-
template <typename T, typename R>
70-
using copy_cv_qualifiers_t = typename copy_cv_qualifiers<T, R>::type;
71-
7287
// make_signed with support SYCL vec class
73-
7488
template <typename T, typename Enable = void> struct make_signed_impl;
7589

7690
template <typename T>
@@ -85,7 +99,7 @@ struct make_signed_impl<
8599
template <typename T>
86100
struct make_signed_impl<
87101
T, enable_if_t<is_contained<T, gtl::vector_integer_list>::value, T>> {
88-
using base_type = make_signed_impl_t<get_base_type_t<T>>;
102+
using base_type = make_signed_impl_t<vector_element_t<T>>;
89103
using type = change_base_type_t<T, base_type>;
90104
};
91105

@@ -119,7 +133,7 @@ struct make_unsigned_impl<
119133
template <typename T>
120134
struct make_unsigned_impl<
121135
T, enable_if_t<is_contained<T, gtl::vector_integer_list>::value, T>> {
122-
using base_type = make_unsigned_impl_t<get_base_type_t<T>>;
136+
using base_type = make_unsigned_impl_t<vector_element_t<T>>;
123137
using type = change_base_type_t<T, base_type>;
124138
};
125139

@@ -141,11 +155,11 @@ template <typename T> using make_unsigned_t = typename make_unsigned<T>::type;
141155
// Checks that sizeof base type of T equal N and T satisfies S<T>::value
142156
template <typename T, int N, template <typename> class S>
143157
using is_gen_based_on_type_sizeof =
144-
bool_constant<S<T>::value && (sizeof(get_base_type_t<T>) == N)>;
158+
bool_constant<S<T>::value && (sizeof(vector_element_t<T>) == N)>;
145159

146160
// is_integral
147161
template <typename T>
148-
struct is_integral : std::is_integral<get_base_type_t<T>> {};
162+
struct is_integral : std::is_integral<vector_element_t<T>> {};
149163

150164
// is_floating_point
151165
template <typename T>
@@ -155,7 +169,7 @@ template <> struct is_floating_point_impl<half> : std::true_type {};
155169

156170
template <typename T>
157171
struct is_floating_point
158-
: is_floating_point_impl<remove_cv_t<get_base_type_t<T>>> {};
172+
: is_floating_point_impl<remove_cv_t<vector_element_t<T>>> {};
159173

160174
// is_arithmetic
161175
template <typename T>
@@ -264,7 +278,7 @@ struct make_larger_impl<
264278
};
265279

266280
template <typename T, int N> struct make_larger_impl<vec<T, N>, vec<T, N>> {
267-
using base_type = get_base_type_t<vec<T, N>>;
281+
using base_type = vector_element_t<vec<T, N>>;
268282
using upper_type = typename make_larger_impl<base_type, base_type>::type;
269283
using new_type = vec<upper_type, N>;
270284
static constexpr bool found = !std::is_same<upper_type, void>::value;

sycl/test/type_traits/type_traits.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ void test_change_base_type_t() {
4646
}
4747

4848
template <typename T, typename CheckedT, bool Expected = true>
49-
void test_get_base_type_t() {
50-
static_assert(is_same<d::get_base_type_t<T>, CheckedT>::value == Expected,
49+
void test_vector_element_t() {
50+
static_assert(is_same<d::vector_element_t<T>, CheckedT>::value == Expected,
5151
"");
5252
}
5353

5454
template <typename T, typename CheckedT, bool Expected = true>
5555
void test_nan_types() {
56-
static_assert((sizeof(d::get_base_type_t<d::nan_return_t<T>>) ==
56+
static_assert((sizeof(d::vector_element_t<d::nan_return_t<T>>) ==
5757
sizeof(d::nan_argument_base_t<T>)) == Expected,
5858
"");
5959
}
@@ -85,6 +85,15 @@ void test_is_address_space_compliant() {
8585
"");
8686
}
8787

88+
template <typename T, int Checked, bool Expected = true>
89+
void test_vector_size() {
90+
static_assert((d::vector_size<T>::value == Checked) == Expected, "");
91+
}
92+
93+
template <bool Expected, typename... Args> void test_is_same_vector_size() {
94+
static_assert(d::is_same_vector_size<Args...>::value == Expected, "");
95+
}
96+
8897
int main() {
8998
test_is_pointer<int *>();
9099
test_is_pointer<float *>();
@@ -162,8 +171,14 @@ int main() {
162171
test_change_base_type_t<long, float, float>();
163172
test_change_base_type_t<s::long2, float, s::float2>();
164173

165-
test_get_base_type_t<s::int2, int>();
166-
test_get_base_type_t<int, int>();
174+
test_vector_element_t<int, int>();
175+
test_vector_element_t<const int, const int>();
176+
test_vector_element_t<volatile int, volatile int>();
177+
test_vector_element_t<const volatile int, const volatile int>();
178+
test_vector_element_t<s::int2, int>();
179+
test_vector_element_t<const s::int2, const int>();
180+
test_vector_element_t<volatile s::int2, volatile int>();
181+
test_vector_element_t<const volatile s::int2, const volatile int>();
167182

168183
test_nan_types<s::ushort, s::ushort>();
169184
test_nan_types<s::uint, s::uint>();
@@ -192,5 +207,24 @@ int main() {
192207
test_make_unsigned_t<s::uint2, s::uint2>();
193208
test_make_unsigned_t<const s::uint2, const s::uint2>();
194209

210+
test_vector_size<int, 1>();
211+
test_vector_size<float, 1>();
212+
test_vector_size<double, 1>();
213+
test_vector_size<s::int2, 2>();
214+
test_vector_size<s::float3, 3>();
215+
test_vector_size<s::double4, 4>();
216+
test_vector_size<s::vec<int, 1>, 1>();
217+
218+
test_is_same_vector_size<true, int>();
219+
test_is_same_vector_size<true, s::int2>();
220+
test_is_same_vector_size<true, int, float>();
221+
test_is_same_vector_size<false, int, s::float2>();
222+
test_is_same_vector_size<true, s::int2, s::float2>();
223+
test_is_same_vector_size<false, s::int2, float>();
224+
test_is_same_vector_size<true, s::constant_ptr<int>>();
225+
test_is_same_vector_size<true, s::constant_ptr<s::int2>>();
226+
test_is_same_vector_size<true, s::constant_ptr<s::int2>, s::int2>();
227+
test_is_same_vector_size<false, s::constant_ptr<s::int2>, float>();
228+
195229
return 0;
196230
}

0 commit comments

Comments
 (0)