Skip to content

Commit e1119d9

Browse files
authored
[SYCL][ESIMD] Add more compile time checks to rdregion and wrregion API (#13158)
1 parent d1441e8 commit e1119d9

File tree

5 files changed

+103
-36
lines changed

5 files changed

+103
-36
lines changed

sycl/include/sycl/ext/intel/esimd/detail/intrin.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@
6464
//
6565
template <typename T, int N, int M, int VStride, int Width, int Stride,
6666
int ParentWidth = 0>
67-
__ESIMD_INTRIN std::enable_if_t<(Width > 0) && M % Width == 0,
68-
__ESIMD_DNS::vector_type_t<T, M>>
67+
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<T, M>
6968
__esimd_rdregion(__ESIMD_DNS::vector_type_t<T, N> Input, uint16_t Offset);
7069

7170
template <typename T, int N, int M, int ParentWidth = 0>
@@ -121,8 +120,7 @@ __esimd_rdindirect(__ESIMD_DNS::vector_type_t<T, N> Input,
121120
//
122121
template <typename T, int N, int M, int VStride, int Width, int Stride,
123122
int ParentWidth = 0>
124-
__ESIMD_INTRIN std::enable_if_t<M <= N && (Width > 0) && M % Width == 0,
125-
__ESIMD_DNS::vector_type_t<T, N>>
123+
__ESIMD_INTRIN std::enable_if_t<M <= N, __ESIMD_DNS::vector_type_t<T, N>>
126124
__esimd_wrregion(__ESIMD_DNS::vector_type_t<T, N> OldVal,
127125
__ESIMD_DNS::vector_type_t<T, M> NewVal, uint16_t Offset,
128126
__ESIMD_DNS::simd_mask_storage_t<M> Mask = 1);
@@ -142,9 +140,9 @@ template <class T> using __st = __raw_t<T>;
142140

143141
/// read from a basic region of a vector, return a vector
144142
template <typename BT, int BN, typename RTy>
145-
__ESIMD_DNS::vector_type_t<__st<typename RTy::element_type>,
146-
RTy::length> ESIMD_INLINE
147-
readRegion(const __ESIMD_DNS::vector_type_t<__st<BT>, BN> &Base, RTy Region) {
143+
__ESIMD_DNS::vector_type_t<__st<typename RTy::element_type>, RTy::length>
144+
ESIMD_INLINE readRegion(
145+
const __ESIMD_DNS::vector_type_t<__st<BT>, BN> &Base, RTy Region) {
148146
using ElemTy = __st<typename RTy::element_type>;
149147
auto Base1 = bitcast<ElemTy, __st<BT>, BN>(Base);
150148
constexpr int Bytes = BN * sizeof(BT);
@@ -159,6 +157,7 @@ readRegion(const __ESIMD_DNS::vector_type_t<__st<BT>, BN> &Base, RTy Region) {
159157
constexpr int Stride = RTy::Stride_x;
160158
int16_t Offset = static_cast<int16_t>(Region.M_offset_x * sizeof(ElemTy));
161159
// read-region
160+
check_rdregion_params<N, M, /*VS*/ 0, M, Stride>();
162161
return __esimd_rdregion<ElemTy, N, M, /*VS*/ 0, M, Stride>(Base1, Offset);
163162
}
164163
}
@@ -191,7 +190,7 @@ ESIMD_INLINE
191190
constexpr int ParentWidth = PaTy::Size_x;
192191
uint16_t Offset = static_cast<uint16_t>(Region.first.M_offset_y *
193192
PaTy::Size_x * sizeof(ElemTy));
194-
193+
check_rdregion_params<BN1, M, VS, W, HS>();
195194
auto R =
196195
__esimd_rdregion<ElemTy, BN1, M, VS, W, HS, ParentWidth>(Base1, Offset);
197196

@@ -203,6 +202,7 @@ ESIMD_INLINE
203202
constexpr int HS1 = T::Stride_x;
204203
uint16_t Offset1 =
205204
static_cast<uint16_t>(Region.first.M_offset_x * sizeof(ElemTy));
205+
check_rdregion_params<N1, M1, VS1, W1, HS1>();
206206

207207
return __esimd_rdregion<ElemTy, N1, M1, VS1, W1, HS1, ParentWidth>(R,
208208
Offset1);
@@ -263,8 +263,7 @@ __ESIMD_INTRIN uint16_t __esimd_all(__ESIMD_DNS::vector_type_t<T, N> src)
263263
// Implementations of ESIMD intrinsics for the SYCL host device
264264
template <typename T, int N, int M, int VStride, int Width, int Stride,
265265
int ParentWidth>
266-
__ESIMD_INTRIN std::enable_if_t<(Width > 0) && M % Width == 0,
267-
__ESIMD_DNS::vector_type_t<T, M>>
266+
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<T, M>
268267
__esimd_rdregion(__ESIMD_DNS::vector_type_t<T, N> Input, uint16_t Offset) {
269268
uint16_t EltOffset = Offset / sizeof(T);
270269
assert(Offset % sizeof(T) == 0);
@@ -298,8 +297,7 @@ __esimd_rdindirect(__ESIMD_DNS::vector_type_t<T, N> Input,
298297

299298
template <typename T, int N, int M, int VStride, int Width, int Stride,
300299
int ParentWidth>
301-
__ESIMD_INTRIN std::enable_if_t<M <= N && (Width > 0) && M % Width == 0,
302-
__ESIMD_DNS::vector_type_t<T, N>>
300+
__ESIMD_INTRIN std::enable_if_t<M <= N, __ESIMD_DNS::vector_type_t<T, N>>
303301
__esimd_wrregion(__ESIMD_DNS::vector_type_t<T, N> OldVal,
304302
__ESIMD_DNS::vector_type_t<T, M> NewVal, uint16_t Offset,
305303
__ESIMD_DNS::simd_mask_storage_t<M> Mask) {

sycl/include/sycl/ext/intel/esimd/detail/simd_obj_impl.hpp

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ class [[__sycl_detail__::__uses_aspects__(
405405
/// @param Val The object to take new values from.
406406
/// @param Mask The mask.
407407
void merge(const Derived &Val, const simd_mask_type<N> &Mask) {
408+
check_wrregion_params<N, N, 0 /*VS*/, N, 1>();
408409
set(__esimd_wrregion<RawTy, N, N, 0 /*VS*/, N, 1, N>(data(), Val.data(), 0,
409410
Mask.data()));
410411
}
@@ -478,6 +479,7 @@ class [[__sycl_detail__::__uses_aspects__(
478479
static_assert(Size > 1 || Stride == 1,
479480
"Stride must be 1 in single-element region");
480481
Derived &&Val = std::move(cast_this_to_derived());
482+
check_rdregion_params<N, Size, /*VS*/ 0, Size, Stride>();
481483
return __esimd_rdregion<RawTy, N, Size, /*VS*/ 0, Size, Stride>(Val.data(),
482484
Offset);
483485
}
@@ -614,6 +616,7 @@ class [[__sycl_detail__::__uses_aspects__(
614616
template <int Rep, int VS, int W, int HS>
615617
resize_a_simd_type_t<Derived, Rep * W>
616618
replicate_vs_w_hs(uint16_t Offset) const {
619+
check_rdregion_params<N, Rep * W, VS, W, HS>();
617620
return __esimd_rdregion<RawTy, N, Rep * W, VS, W, HS, N>(
618621
data(), Offset * sizeof(RawTy));
619622
}
@@ -655,14 +658,7 @@ class [[__sycl_detail__::__uses_aspects__(
655658
constexpr int M = RTy::Size_x;
656659
constexpr int Stride = RTy::Stride_x;
657660
uint16_t Offset = Region.M_offset_x * sizeof(ElemTy);
658-
static_assert(M > 0, "Malformed RHS region.");
659-
static_assert(M <= BN, "Attempt to write beyond viewed area: The viewed "
660-
"object in LHS does not fit RHS.");
661-
// (M > BN) condition is added below to not duplicate the above assert
662-
// for big values of M. The assert below is for 'Stride'.
663-
static_assert((M > BN) || (M - 1) * Stride < BN,
664-
"Malformed RHS region - too big stride.");
665-
661+
check_wrregion_params<BN, M, /*VS*/ 0, M, Stride>();
666662
// Merge and update.
667663
auto Merged = __esimd_wrregion<ElemTy, BN, M,
668664
/*VS*/ 0, M, Stride>(Base, Val, Offset);
@@ -697,11 +693,7 @@ class [[__sycl_detail__::__uses_aspects__(
697693
constexpr int Stride = TR::Stride_x;
698694
uint16_t Offset = Region.first.M_offset_x * sizeof(ElemTy);
699695

700-
static_assert(M <= BN1, "Attempt to write beyond viewed area: The "
701-
"viewed object in LHS does not fit RHS.");
702-
static_assert(M > 0, "Malformed RHS region.");
703-
static_assert((M - 1) * Stride < BN,
704-
"Malformed RHS region - too big stride.");
696+
check_wrregion_params<BN1, M, /*VS*/ 0, M, Stride>();
705697
// Merge and update.
706698
Base1 = __esimd_wrregion<ElemTy, BN1, M,
707699
/*VS*/ 0, M, Stride>(Base1, Val, Offset);
@@ -719,12 +711,7 @@ class [[__sycl_detail__::__uses_aspects__(
719711
(Region.first.M_offset_y * PaTy::Size_x + Region.first.M_offset_x) *
720712
sizeof(ElemTy));
721713

722-
static_assert(M <= BN1, "Attempt to write beyond viewed area: The "
723-
"viewed object in LHS does not fit RHS.");
724-
static_assert(M > 0 && W > 0 && M % W == 0, "Malformed RHS region.");
725-
static_assert(W == 0 || ((M / W) - 1) * VS + (W - 1) * HS < BN1,
726-
"Malformed RHS region - too big vertical and/or "
727-
"horizontal stride.");
714+
check_wrregion_params<BN1, M, VS, W, HS>();
728715
// Merge and update.
729716
Base1 = __esimd_wrregion<ElemTy, BN1, M, VS, W, HS, ParentWidth>(
730717
Base1, Val, Offset);

sycl/include/sycl/ext/intel/esimd/detail/util.hpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,76 @@ auto accessorToPointer(AccessorTy Acc, OffsetTy Offset = 0) {
209209
}
210210
#endif // __ESIMD_FORCE_STATELESS_MEM
211211

212+
/// @brief Checks parameters for read region intrinsic API. The checks were
213+
/// refactored from simd_obj API.
214+
/// @tparam N the input vector size.
215+
/// @tparam M the return vector size.
216+
/// @tparam VStride the vertical stride in elements between rows.
217+
/// @tparam Width the size of each row; must be non-zero and evenly divide `M`.
218+
/// @tparam Stride horizontal stride in elements within each row.
219+
// The rdregion intrinsic computes a result vector using the following algorithm:
220+
//
221+
// \code{.cpp}
222+
// uint16_t EltOffset = Offset / sizeof(T);
223+
// assert(Offset % sizeof(T) == 0);
224+
//
225+
// int NumRows = M / Width;
226+
// assert(M % Width == 0);
227+
//
228+
// int Index = 0;
229+
// for (int i = 0; i < NumRows; ++i) {
230+
// for (int j = 0; j < Width; ++j) {
231+
// Result[Index++] = Input[i * VStride + j * Stride +
232+
// EltOffset];
233+
// }
234+
// }
235+
// \endcode
236+
// Hence the checks are to prevent reading beyond the input vector.
237+
template <int N, int M, int VStride, int Width, int Stride>
238+
constexpr void check_rdregion_params() {
239+
static_assert(Width > 0 && M % Width == 0, "Malformed RHS region.");
240+
static_assert(Width == M ||
241+
((M / Width) - 1) * VStride + (Width - 1) * Stride < N,
242+
"Malformed RHS region - too big vertical and/or "
243+
"horizontal stride.");
244+
}
245+
246+
/// @brief Checks parameters for write region intrinsic API. The checks were
247+
/// refactored from simd_obj API.
248+
/// @tparam N the size of the destination vector (the old value and the result).
249+
/// @tparam M the size of the vector of new values being written.
250+
/// @tparam VStride the vertical stride in elements between rows.
251+
/// @tparam Width the size of each row; must be non-zero and evenly divide `M`.
252+
/// @tparam Stride horizontal stride in elements within each row.
253+
// The wrregion intrinsic computes a result vector using the following algorithm:
254+
//
255+
// \code{.cpp}
256+
// uint16_t EltOffset = Offset / sizeof(T);
257+
// assert(Offset % sizeof(T) == 0);
258+
//
259+
// int NumRows = M / Width;
260+
// assert(M % Width == 0);
261+
//
262+
// Result = OldValue;
263+
// int Index = 0;
264+
// for (int i = 0; i < NumRows; ++i) {
265+
// for (int j = 0; j < Width; ++j) {
266+
// if (Mask[Index])
267+
// Result[i * VStride + j * Stride + EltOffset] = NewVal[Index];
268+
// ++Index;
269+
// }
270+
// }
271+
// \endcode
272+
// Hence the checks are to prevent reading beyond the input array and prevent
273+
// writing beyond the destination vector.
274+
template <int N, int M, int VStride, int Width, int Stride>
275+
constexpr void check_wrregion_params() {
276+
static_assert(M <= N, "Attempt to access beyond viewed area: The "
277+
"viewed object in LHS does not fit RHS.");
278+
static_assert((M - 1) * Stride < N, "Malformed RHS region - too big stride.");
279+
check_rdregion_params<N, M, VStride, Width, Stride>();
280+
}
281+
212282
} // namespace ext::intel::esimd::detail
213283
} // namespace _V1
214284
} // namespace sycl

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,11 +365,13 @@ gather(const T *p, simd<OffsetT, N / VS> byte_offsets, simd_mask<N / VS> mask,
365365
auto Ret = __esimd_svm_gather<MsgT, N, detail::ElemsPerAddrEncoding<4>(),
366366
detail::ElemsPerAddrEncoding<1>()>(
367367
Addrs.data(), mask.data());
368+
detail::check_rdregion_params<N * 4, N, /*VS*/ 0, N, 4>();
368369
return __esimd_rdregion<MsgT, N * 4, N, /*VS*/ 0, N, 4>(Ret, 0);
369370
} else if constexpr (sizeof(T) == 2) {
370371
auto Ret = __esimd_svm_gather<MsgT, N, detail::ElemsPerAddrEncoding<2>(),
371372
detail::ElemsPerAddrEncoding<2>()>(
372373
Addrs.data(), mask.data());
374+
detail::check_rdregion_params<N * 2, N, /*VS*/ 0, N, 2>();
373375
return __esimd_rdregion<MsgT, N * 2, N, /*VS*/ 0, N, 2>(Ret, 0);
374376
} else {
375377
return __esimd_svm_gather<MsgT, N, detail::ElemsPerAddrEncoding<1>(),
@@ -703,12 +705,14 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
703705
simd<uint64_t, N> addrs(reinterpret_cast<uint64_t>(p));
704706
addrs = addrs + byte_offsets_i;
705707
if constexpr (sizeof(T) == 1) {
708+
detail::check_wrregion_params<N * 4, N, /*VS*/ 0, N, 4>();
706709
simd<T, N * 4> D = __esimd_wrregion<Tx, N * 4, N, /*VS*/ 0, N, 4>(
707710
D.data(), vals.data(), 0);
708711
__esimd_svm_scatter<Tx, N, detail::ElemsPerAddrEncoding<4>(),
709712
detail::ElemsPerAddrEncoding<1>()>(
710713
addrs.data(), D.data(), mask.data());
711714
} else if constexpr (sizeof(T) == 2) {
715+
detail::check_wrregion_params<N * 2, N, /*VS*/ 0, N, 2>();
712716
simd<Tx, N * 2> D = __esimd_wrregion<Tx, N * 2, N, /*VS*/ 0, N, 2>(
713717
D.data(), vals.data(), 0);
714718
__esimd_svm_scatter<Tx, N, detail::ElemsPerAddrEncoding<2>(),

sycl/test/esimd/wrregion.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,27 @@ SYCL_ESIMD_FUNCTION void test_wrregion_size_check() {
1111
simd<int, 16> v16 = 0;
1212
simd<int, 64> v64;
1313
v16.template select<64, 1>(0) = v64;
14-
// expected-error@* {{static assertion failed due to requirement 'M <= BN'}}
14+
// expected-error@sycl/ext/intel/esimd/detail/util.hpp:* {{static assertion failed due to requirement '64 <= 16'}}
15+
// expected-note@sycl/ext/intel/esimd/detail/simd_obj_impl.hpp:* {{in instantiation of function template specialization}}
1516
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of function template specialization}}
1617
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of member function}}
17-
// expected-note@wrregion.cpp:* {{in instantiation of member function}}
18-
// expected-note@sycl/ext/intel/esimd/detail/simd_obj_impl.hpp:* {{expression evaluates to '64 <= 16'}}
18+
// expected-note@* {{in instantiation of member function}}
19+
20+
// expected-error@sycl/ext/intel/esimd/detail/util.hpp:* {{static assertion failed due to requirement '(64 - 1) * 1 < 16'}}
21+
// expected-note@* {{expression evaluates to '63 < 16'}}
1922

2023
// expected-error@* {{no matching function for call to '__esimd_wrregion'}}
24+
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of function template specialization}}
25+
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of member function}}
26+
// expected-note@* {{in instantiation of member function}}
2127
// expected-note@sycl/ext/intel/esimd/detail/intrin.hpp:* {{candidate template ignored: requirement '64 <= 16' was not satisfied}}
2228

2329
simd<int, 2> v2;
2430
v16.template select<2, 64>() = v2;
25-
// expected-error@* {{static assertion failed due to requirement '(M > BN) || (M - 1) * Stride < BN'}}
31+
// expected-error@sycl/ext/intel/esimd/detail/util.hpp:* {{static assertion failed due to requirement '(2 - 1) * 64 < 16'}}
32+
// expected-note@sycl/ext/intel/esimd/detail/simd_obj_impl.hpp:* {{in instantiation of function template specialization}}
2633
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of function template specialization}}
2734
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of member function}}
28-
// expected-note@wrregion.cpp:* {{in instantiation of member function}}
35+
// expected-note@* {{in instantiation of member function}}
36+
// expected-note@* {{expression evaluates to '64 < 16'}}
2937
}

0 commit comments

Comments
 (0)