Skip to content

[SYCL][ESIMD] Add more compile time checks to rdregion and wrregion API #13158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 25, 2024
22 changes: 10 additions & 12 deletions sycl/include/sycl/ext/intel/esimd/detail/intrin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@
//
template <typename T, int N, int M, int VStride, int Width, int Stride,
int ParentWidth = 0>
__ESIMD_INTRIN std::enable_if_t<(Width > 0) && M % Width == 0,
__ESIMD_DNS::vector_type_t<T, M>>
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<T, M>
__esimd_rdregion(__ESIMD_DNS::vector_type_t<T, N> Input, uint16_t Offset);

template <typename T, int N, int M, int ParentWidth = 0>
Expand Down Expand Up @@ -121,8 +120,7 @@ __esimd_rdindirect(__ESIMD_DNS::vector_type_t<T, N> Input,
//
template <typename T, int N, int M, int VStride, int Width, int Stride,
int ParentWidth = 0>
__ESIMD_INTRIN std::enable_if_t<M <= N && (Width > 0) && M % Width == 0,
__ESIMD_DNS::vector_type_t<T, N>>
__ESIMD_INTRIN std::enable_if_t<M <= N, __ESIMD_DNS::vector_type_t<T, N>>
__esimd_wrregion(__ESIMD_DNS::vector_type_t<T, N> OldVal,
__ESIMD_DNS::vector_type_t<T, M> NewVal, uint16_t Offset,
__ESIMD_DNS::simd_mask_storage_t<M> Mask = 1);
Expand All @@ -142,9 +140,9 @@ template <class T> using __st = __raw_t<T>;

/// read from a basic region of a vector, return a vector
template <typename BT, int BN, typename RTy>
__ESIMD_DNS::vector_type_t<__st<typename RTy::element_type>,
RTy::length> ESIMD_INLINE
readRegion(const __ESIMD_DNS::vector_type_t<__st<BT>, BN> &Base, RTy Region) {
__ESIMD_DNS::vector_type_t<__st<typename RTy::element_type>, RTy::length>
ESIMD_INLINE readRegion(
const __ESIMD_DNS::vector_type_t<__st<BT>, BN> &Base, RTy Region) {
using ElemTy = __st<typename RTy::element_type>;
auto Base1 = bitcast<ElemTy, __st<BT>, BN>(Base);
constexpr int Bytes = BN * sizeof(BT);
Expand All @@ -159,6 +157,7 @@ readRegion(const __ESIMD_DNS::vector_type_t<__st<BT>, BN> &Base, RTy Region) {
constexpr int Stride = RTy::Stride_x;
int16_t Offset = static_cast<int16_t>(Region.M_offset_x * sizeof(ElemTy));
// read-region
check_rdregion_params<N, M, /*VS*/ 0, M, Stride>();
return __esimd_rdregion<ElemTy, N, M, /*VS*/ 0, M, Stride>(Base1, Offset);
}
}
Expand Down Expand Up @@ -191,7 +190,7 @@ ESIMD_INLINE
constexpr int ParentWidth = PaTy::Size_x;
uint16_t Offset = static_cast<uint16_t>(Region.first.M_offset_y *
PaTy::Size_x * sizeof(ElemTy));

check_rdregion_params<BN1, M, VS, W, HS>();
auto R =
__esimd_rdregion<ElemTy, BN1, M, VS, W, HS, ParentWidth>(Base1, Offset);

Expand All @@ -203,6 +202,7 @@ ESIMD_INLINE
constexpr int HS1 = T::Stride_x;
uint16_t Offset1 =
static_cast<uint16_t>(Region.first.M_offset_x * sizeof(ElemTy));
check_rdregion_params<N1, M1, VS1, W1, HS1>();

return __esimd_rdregion<ElemTy, N1, M1, VS1, W1, HS1, ParentWidth>(R,
Offset1);
Expand Down Expand Up @@ -263,8 +263,7 @@ __ESIMD_INTRIN uint16_t __esimd_all(__ESIMD_DNS::vector_type_t<T, N> src)
// Implementations of ESIMD intrinsics for the SYCL host device
template <typename T, int N, int M, int VStride, int Width, int Stride,
int ParentWidth>
__ESIMD_INTRIN std::enable_if_t<(Width > 0) && M % Width == 0,
__ESIMD_DNS::vector_type_t<T, M>>
__ESIMD_INTRIN __ESIMD_DNS::vector_type_t<T, M>
__esimd_rdregion(__ESIMD_DNS::vector_type_t<T, N> Input, uint16_t Offset) {
uint16_t EltOffset = Offset / sizeof(T);
assert(Offset % sizeof(T) == 0);
Expand Down Expand Up @@ -298,8 +297,7 @@ __esimd_rdindirect(__ESIMD_DNS::vector_type_t<T, N> Input,

template <typename T, int N, int M, int VStride, int Width, int Stride,
int ParentWidth>
__ESIMD_INTRIN std::enable_if_t<M <= N && (Width > 0) && M % Width == 0,
__ESIMD_DNS::vector_type_t<T, N>>
__ESIMD_INTRIN std::enable_if_t<M <= N, __ESIMD_DNS::vector_type_t<T, N>>
__esimd_wrregion(__ESIMD_DNS::vector_type_t<T, N> OldVal,
__ESIMD_DNS::vector_type_t<T, M> NewVal, uint16_t Offset,
__ESIMD_DNS::simd_mask_storage_t<M> Mask) {
Expand Down
25 changes: 6 additions & 19 deletions sycl/include/sycl/ext/intel/esimd/detail/simd_obj_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ class [[__sycl_detail__::__uses_aspects__(
/// @param Val The object to take new values from.
/// @param Mask The mask.
void merge(const Derived &Val, const simd_mask_type<N> &Mask) {
check_wrregion_params<N, N, 0 /*VS*/, N, 1>();
set(__esimd_wrregion<RawTy, N, N, 0 /*VS*/, N, 1, N>(data(), Val.data(), 0,
Mask.data()));
}
Expand Down Expand Up @@ -478,6 +479,7 @@ class [[__sycl_detail__::__uses_aspects__(
static_assert(Size > 1 || Stride == 1,
"Stride must be 1 in single-element region");
Derived &&Val = std::move(cast_this_to_derived());
check_rdregion_params<N, Size, /*VS*/ 0, Size, Stride>();
return __esimd_rdregion<RawTy, N, Size, /*VS*/ 0, Size, Stride>(Val.data(),
Offset);
}
Expand Down Expand Up @@ -614,6 +616,7 @@ class [[__sycl_detail__::__uses_aspects__(
template <int Rep, int VS, int W, int HS>
resize_a_simd_type_t<Derived, Rep * W>
replicate_vs_w_hs(uint16_t Offset) const {
check_rdregion_params<N, Rep * W, VS, W, HS>();
return __esimd_rdregion<RawTy, N, Rep * W, VS, W, HS, N>(
data(), Offset * sizeof(RawTy));
}
Expand Down Expand Up @@ -655,14 +658,7 @@ class [[__sycl_detail__::__uses_aspects__(
constexpr int M = RTy::Size_x;
constexpr int Stride = RTy::Stride_x;
uint16_t Offset = Region.M_offset_x * sizeof(ElemTy);
static_assert(M > 0, "Malformed RHS region.");
static_assert(M <= BN, "Attempt to write beyond viewed area: The viewed "
"object in LHS does not fit RHS.");
// (M > BN) condition is added below to not duplicate the above assert
// for big values of M. The assert below is for 'Stride'.
static_assert((M > BN) || (M - 1) * Stride < BN,
"Malformed RHS region - too big stride.");

check_wrregion_params<BN, M, /*VS*/ 0, M, Stride>();
// Merge and update.
auto Merged = __esimd_wrregion<ElemTy, BN, M,
/*VS*/ 0, M, Stride>(Base, Val, Offset);
Expand Down Expand Up @@ -697,11 +693,7 @@ class [[__sycl_detail__::__uses_aspects__(
constexpr int Stride = TR::Stride_x;
uint16_t Offset = Region.first.M_offset_x * sizeof(ElemTy);

static_assert(M <= BN1, "Attempt to write beyond viewed area: The "
"viewed object in LHS does not fit RHS.");
static_assert(M > 0, "Malformed RHS region.");
static_assert((M - 1) * Stride < BN,
"Malformed RHS region - too big stride.");
check_wrregion_params<BN1, M, /*VS*/ 0, M, Stride>();
// Merge and update.
Base1 = __esimd_wrregion<ElemTy, BN1, M,
/*VS*/ 0, M, Stride>(Base1, Val, Offset);
Expand All @@ -719,12 +711,7 @@ class [[__sycl_detail__::__uses_aspects__(
(Region.first.M_offset_y * PaTy::Size_x + Region.first.M_offset_x) *
sizeof(ElemTy));

static_assert(M <= BN1, "Attempt to write beyond viewed area: The "
"viewed object in LHS does not fit RHS.");
static_assert(M > 0 && W > 0 && M % W == 0, "Malformed RHS region.");
static_assert(W == 0 || ((M / W) - 1) * VS + (W - 1) * HS < BN1,
"Malformed RHS region - too big vertical and/or "
"horizontal stride.");
check_wrregion_params<BN1, M, VS, W, HS>();
// Merge and update.
Base1 = __esimd_wrregion<ElemTy, BN1, M, VS, W, HS, ParentWidth>(
Base1, Val, Offset);
Expand Down
70 changes: 70 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/detail/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,76 @@ auto accessorToPointer(AccessorTy Acc, OffsetTy Offset = 0) {
}
#endif // __ESIMD_FORCE_STATELESS_MEM

/// @brief Checks parameters for read region intrinsic API. The checks were
/// refactored from simd_obj API.
/// @tparam N the input vector size.
/// @tparam M the return vector size.
/// @tparam VStride the vertical stride in elements between rows.
/// @tparam Width the size of each row; must be non-zero and evenly divide `M`.
/// @tparam Stride horizontal stride in elements within each row.
// The rdregion intrinsic computes a result vector using the following algorithm:
//
// \code{.cpp}
// uint16_t EltOffset = Offset / sizeof(T);
// assert(Offset % sizeof(T) == 0);
//
// int NumRows = M / Width;
// assert(M % Width == 0);
//
// int Index = 0;
// for (int i = 0; i < NumRows; ++i) {
// for (int j = 0; j < Width; ++j) {
// Result[Index++] = Input[i * VStride + j * Stride +
// EltOffset];
// }
// }
// \endcode
// Hence the checks are to prevent reading beyond the input vector.
template <int N, int M, int VStride, int Width, int Stride>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please write a detailed comment to these 2 functions and their params. Otherwise, it is really difficult (or impossible) to understand what they verify.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. It is still difficult to understand why the constraint is such: Width == M || ((M / Width) - 1) * VStride + (Width - 1) * Stride < N. It would help a lot if the pseudo-code were copied from __esimd_rdregion/wrregion to the comment of these check functions, e.g.:
// This intrinsic computes a vector Result:
//
// \code{.cpp}
// uint16_t EltOffset = Offset / sizeof(T);
// assert(Offset % sizeof(T) == 0);
//
// int NumRows = M / Width;
// assert(M % Width == 0);
//
// int Index = 0;
// for (int i = 0; i < NumRows; ++i) {
// for (int j = 0; j < Width; ++j) {
// Result[Index++] = Input[i * VStride + j * Stride +
// EltOffset];
// }
// }
// \endcode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do the same for __esimd_wrregion please?
Perhaps, fix the comment in both places too - comment does not say that the write is skipped if mask is 0.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion. The pseudo-code for __esimd_wrregion is correct. Please just copy it to the check function here.
The __esimd_rdregion does not have a mask operand, thus the pseudo-code for rdregion is all right as it is.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

constexpr void check_rdregion_params() {
// The region is processed as M / Width rows of Width elements each, so Width
// must be positive and must divide M without a remainder.
static_assert(Width > 0 && M % Width == 0, "Malformed RHS region.");
// The largest element index touched by the region (not counting the run-time
// offset, which is not a template parameter and cannot be checked here) is
//   (NumRows - 1) * VStride + (Width - 1) * Stride, with NumRows = M / Width.
// It must stay below the number of elements N in the accessed vector.
// NOTE(review): the bound is not enforced when Width == M (single-row
// region), so an oversized horizontal stride passes in that case - the reason
// is not visible from this code; confirm that this is intended.
static_assert(Width == M ||
((M / Width) - 1) * VStride + (Width - 1) * Stride < N,
"Malformed RHS region - too big vertical and/or "
"horizontal stride.");
}

/// @brief Checks parameters for write region intrinsic API. The checks were
/// refactored from simd_obj API.
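/// @tparam N the size of the vector being updated (`OldVal`), which is also
/// the size of the returned vector.
/// @tparam M the size of the vector of new values (`NewVal`) to write.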
/// @tparam VStride the vertical stride in elements between rows.
/// @tparam Width the size of each row; must be non-zero and evenly divide `M`.
/// @tparam Stride horizontal stride in elements within each row.
// The wrregion intrinsic computes a result vector using the following algorithm:
//
// \code{.cpp}
// uint16_t EltOffset = Offset / sizeof(T);
// assert(Offset % sizeof(T) == 0);
//
// int NumRows = M / Width;
// assert(M % Width == 0);
//
// Result = OldValue;
// int Index = 0;
// for (int i = 0; i < NumRows; ++i) {
// for (int j = 0; j < Width; ++j) {
// if (Mask[Index])
// Result[i * VStride + j * Stride + EltOffset] = NewVal[Index];
// ++Index;
// }
// }
// \endcode
// Hence the checks are to prevent reading beyond the input array and prevent
// writing beyond destination vector.
template <int N, int M, int VStride, int Width, int Stride>
constexpr void check_wrregion_params() {
// The M values being written cannot outnumber the N elements of the vector
// that receives them.
static_assert(M <= N, "Attempt to access beyond viewed area: The "
"viewed object in LHS does not fit RHS.");
// Bounds the index of the last written element as if all M values formed a
// single row spaced by Stride.
// NOTE(review): this is applied regardless of Width; for multi-row regions
// (Width != M) it can be stricter than the row-based bound verified by
// check_rdregion_params below - confirm that this is intended.
static_assert((M - 1) * Stride < N, "Malformed RHS region - too big stride.");
// The remaining constraints (Width is positive and divides M, the combined
// vertical/horizontal stride bound) are shared with the read-region check.
check_rdregion_params<N, M, VStride, Width, Stride>();
}

} // namespace ext::intel::esimd::detail
} // namespace _V1
} // namespace sycl
Expand Down
4 changes: 4 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,13 @@ gather(const T *p, simd<OffsetT, N / VS> byte_offsets, simd_mask<N / VS> mask,
auto Ret = __esimd_svm_gather<MsgT, N, detail::ElemsPerAddrEncoding<4>(),
detail::ElemsPerAddrEncoding<1>()>(
Addrs.data(), mask.data());
detail::check_rdregion_params<N * 4, N, /*VS*/ 0, N, 4>();
return __esimd_rdregion<MsgT, N * 4, N, /*VS*/ 0, N, 4>(Ret, 0);
} else if constexpr (sizeof(T) == 2) {
auto Ret = __esimd_svm_gather<MsgT, N, detail::ElemsPerAddrEncoding<2>(),
detail::ElemsPerAddrEncoding<2>()>(
Addrs.data(), mask.data());
detail::check_rdregion_params<N * 2, N, /*VS*/ 0, N, 2>();
return __esimd_rdregion<MsgT, N * 2, N, /*VS*/ 0, N, 2>(Ret, 0);
} else {
return __esimd_svm_gather<MsgT, N, detail::ElemsPerAddrEncoding<1>(),
Expand Down Expand Up @@ -703,12 +705,14 @@ scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
simd<uint64_t, N> addrs(reinterpret_cast<uint64_t>(p));
addrs = addrs + byte_offsets_i;
if constexpr (sizeof(T) == 1) {
detail::check_wrregion_params<N * 4, N, /*VS*/ 0, N, 4>();
simd<T, N * 4> D = __esimd_wrregion<Tx, N * 4, N, /*VS*/ 0, N, 4>(
D.data(), vals.data(), 0);
__esimd_svm_scatter<Tx, N, detail::ElemsPerAddrEncoding<4>(),
detail::ElemsPerAddrEncoding<1>()>(
addrs.data(), D.data(), mask.data());
} else if constexpr (sizeof(T) == 2) {
detail::check_wrregion_params<N * 2, N, /*VS*/ 0, N, 2>();
simd<Tx, N * 2> D = __esimd_wrregion<Tx, N * 2, N, /*VS*/ 0, N, 2>(
D.data(), vals.data(), 0);
__esimd_svm_scatter<Tx, N, detail::ElemsPerAddrEncoding<2>(),
Expand Down
18 changes: 13 additions & 5 deletions sycl/test/esimd/wrregion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,27 @@ SYCL_ESIMD_FUNCTION void test_wrregion_size_check() {
simd<int, 16> v16 = 0;
simd<int, 64> v64;
v16.template select<64, 1>(0) = v64;
// expected-error@* {{static assertion failed due to requirement 'M <= BN'}}
// expected-error@sycl/ext/intel/esimd/detail/util.hpp:* {{static assertion failed due to requirement '64 <= 16'}}
// expected-note@sycl/ext/intel/esimd/detail/simd_obj_impl.hpp:* {{in instantiation of function template specialization}}
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of function template specialization}}
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of member function}}
// expected-note@wrregion.cpp:* {{in instantiation of member function}}
// expected-note@sycl/ext/intel/esimd/detail/simd_obj_impl.hpp:* {{expression evaluates to '64 <= 16'}}
// expected-note@* {{in instantiation of member function}}

// expected-error@sycl/ext/intel/esimd/detail/util.hpp:* {{static assertion failed due to requirement '(64 - 1) * 1 < 16'}}
// expected-note@* {{expression evaluates to '63 < 16'}}

// expected-error@* {{no matching function for call to '__esimd_wrregion'}}
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of function template specialization}}
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of member function}}
// expected-note@* {{in instantiation of member function}}
// expected-note@sycl/ext/intel/esimd/detail/intrin.hpp:* {{candidate template ignored: requirement '64 <= 16' was not satisfied}}

simd<int, 2> v2;
v16.template select<2, 64>() = v2;
// expected-error@* {{static assertion failed due to requirement '(M > BN) || (M - 1) * Stride < BN'}}
// expected-error@sycl/ext/intel/esimd/detail/util.hpp:* {{static assertion failed due to requirement '(2 - 1) * 64 < 16'}}
// expected-note@sycl/ext/intel/esimd/detail/simd_obj_impl.hpp:* {{in instantiation of function template specialization}}
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of function template specialization}}
// expected-note@sycl/ext/intel/esimd/detail/simd_view_impl.hpp:* {{in instantiation of member function}}
// expected-note@wrregion.cpp:* {{in instantiation of member function}}
// expected-note@* {{in instantiation of member function}}
// expected-note@* {{expression evaluates to '64 < 16'}}
}