Skip to content

Commit 8d7396d

Browse files
authored
[SYCL][ESIMD] Add more stringent compile time checks to local_accessor version of block_load/block_store, gather/scatter API (#11653)
1 parent 331e513 commit 8d7396d

File tree

3 files changed

+58
-24
lines changed

3 files changed

+58
-24
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3057,11 +3057,12 @@ __ESIMD_API void media_block_store(AccessorTy acc, unsigned x, unsigned y,
30573057
///
30583058
template <typename Tx, int N, typename AccessorTy,
30593059
typename Flags = overaligned_tag<detail::OperandSize::OWORD>>
3060-
__ESIMD_API std::enable_if_t<
3061-
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy> &&
3062-
is_simd_flag_type_v<Flags>,
3063-
simd<Tx, N>>
3064-
block_load(AccessorTy acc, uint32_t offset, Flags = {}) {
3060+
__ESIMD_API
3061+
std::enable_if_t<detail::is_local_accessor_with_v<
3062+
AccessorTy, detail::accessor_mode_cap::can_read> &&
3063+
is_simd_flag_type_v<Flags>,
3064+
simd<Tx, N>>
3065+
block_load(AccessorTy acc, uint32_t offset, Flags = {}) {
30653066
return slm_block_load<Tx, N, Flags>(offset +
30663067
__ESIMD_DNS::localAccessorToOffset(acc));
30673068
}
@@ -3085,10 +3086,11 @@ block_load(AccessorTy acc, uint32_t offset, Flags = {}) {
30853086
///
30863087
template <typename Tx, int N, typename AccessorTy,
30873088
typename Flags = overaligned_tag<detail::OperandSize::OWORD>>
3088-
__ESIMD_API std::enable_if_t<
3089-
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy> &&
3090-
is_simd_flag_type_v<Flags>>
3091-
block_store(AccessorTy acc, uint32_t offset, simd<Tx, N> vals, Flags = {}) {
3089+
__ESIMD_API
3090+
std::enable_if_t<detail::is_local_accessor_with_v<
3091+
AccessorTy, detail::accessor_mode_cap::can_write> &&
3092+
is_simd_flag_type_v<Flags>>
3093+
block_store(AccessorTy acc, uint32_t offset, simd<Tx, N> vals, Flags = {}) {
30923094
slm_block_store<Tx, N, Flags>(
30933095
offset + __ESIMD_DNS::localAccessorToOffset(acc), vals);
30943096
}
@@ -3111,10 +3113,12 @@ block_store(AccessorTy acc, uint32_t offset, simd<Tx, N> vals, Flags = {}) {
31113113
/// undefined.
31123114
///
31133115
template <typename T, int N, typename AccessorTy>
3114-
__ESIMD_API std::enable_if_t<
3115-
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy>, simd<T, N>>
3116-
gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
3117-
simd_mask<N> mask = 1) {
3116+
__ESIMD_API
3117+
std::enable_if_t<detail::is_local_accessor_with_v<
3118+
AccessorTy, detail::accessor_mode_cap::can_read>,
3119+
simd<T, N>>
3120+
gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
3121+
simd_mask<N> mask = 1) {
31183122
return slm_gather<T, N>(
31193123
offsets + glob_offset + __ESIMD_DNS::localAccessorToOffset(acc), mask);
31203124
}
@@ -3138,8 +3142,8 @@ gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
31383142
///
31393143
///
31403144
template <typename T, int N, typename AccessorTy>
3141-
__ESIMD_API std::enable_if_t<
3142-
sycl::detail::acc_properties::is_local_accessor_v<AccessorTy>>
3145+
__ESIMD_API std::enable_if_t<detail::is_local_accessor_with_v<
3146+
AccessorTy, detail::accessor_mode_cap::can_write>>
31433147
scatter(AccessorTy acc, simd<uint32_t, N> offsets, simd<T, N> vals,
31443148
uint32_t glob_offset = 0, simd_mask<N> mask = 1) {
31453149
slm_scatter<T, N>(offsets + glob_offset +
@@ -3174,11 +3178,12 @@ scatter(AccessorTy acc, simd<uint32_t, N> offsets, simd<T, N> vals,
31743178
template <rgba_channel_mask RGBAMask = rgba_channel_mask::ABGR,
31753179
typename AccessorT, int N,
31763180
typename T = typename AccessorT::value_type>
3177-
__ESIMD_API std::enable_if_t<
3178-
sycl::detail::acc_properties::is_local_accessor_v<AccessorT>,
3179-
simd<T, N * get_num_channels_enabled(RGBAMask)>>
3180-
gather_rgba(AccessorT acc, simd<uint32_t, N> offsets,
3181-
uint32_t global_offset = 0, simd_mask<N> mask = 1) {
3181+
__ESIMD_API
3182+
std::enable_if_t<detail::is_local_accessor_with_v<
3183+
AccessorT, detail::accessor_mode_cap::can_read>,
3184+
simd<T, N * get_num_channels_enabled(RGBAMask)>>
3185+
gather_rgba(AccessorT acc, simd<uint32_t, N> offsets,
3186+
uint32_t global_offset = 0, simd_mask<N> mask = 1) {
31823187
return slm_gather_rgba<T, N, RGBAMask>(
31833188
offsets + global_offset + __ESIMD_DNS::localAccessorToOffset(acc), mask);
31843189
}
@@ -3202,8 +3207,8 @@ gather_rgba(AccessorT acc, simd<uint32_t, N> offsets,
32023207
template <rgba_channel_mask RGBAMask = rgba_channel_mask::ABGR,
32033208
typename AccessorT, int N,
32043209
typename T = typename AccessorT::value_type>
3205-
__ESIMD_API std::enable_if_t<
3206-
sycl::detail::acc_properties::is_local_accessor_v<AccessorT>>
3210+
__ESIMD_API std::enable_if_t<detail::is_local_accessor_with_v<
3211+
AccessorT, detail::accessor_mode_cap::can_write>>
32073212
scatter_rgba(AccessorT acc, simd<uint32_t, N> offsets,
32083213
simd<T, N * get_num_channels_enabled(RGBAMask)> vals,
32093214
uint32_t global_offset = 0, simd_mask<N> mask = 1) {

sycl/test/esimd/block_load_store.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ SYCL_EXTERNAL void kernel2(int *ptr) SYCL_ESIMD_FUNCTION {
3030

3131
// Incompatible mode (write).
3232
SYCL_EXTERNAL void
33-
kernel3(accessor<int, 1, access::mode::write, access::target::device> &buf)
33+
kernel4(accessor<int, 1, access::mode::write, access::target::device> &buf)
3434
SYCL_ESIMD_FUNCTION {
3535
simd<int, 32> v;
3636
// CHECK: block_load_store.cpp:38{{.*}}error: no matching function
@@ -40,10 +40,19 @@ kernel3(accessor<int, 1, access::mode::write, access::target::device> &buf)
4040

4141
// Incompatible mode (read).
4242
SYCL_EXTERNAL void
43-
kernel4(accessor<int, 1, access::mode::read, access::target::device> &buf)
43+
kernel5(accessor<int, 1, access::mode::read, access::target::device> &buf)
4444
SYCL_ESIMD_FUNCTION {
4545
simd<int, 32> v(0, 1);
4646
// CHECK: block_load_store.cpp:48{{.*}}error: no matching function
4747
// function for call to 'block_store'
4848
block_store<int, 32>(buf, 0, v);
4949
}
50+
51+
// Incompatible mode (read).
52+
SYCL_EXTERNAL void
53+
kernel6(local_accessor<const int, 1> &buf) SYCL_ESIMD_FUNCTION {
54+
simd<int, 32> v(0, 1);
55+
// CHECK: block_load_store.cpp:57{{.*}}error: no matching function
56+
// function for call to 'block_store'
57+
block_store<int, 32>(buf, 0, v);
58+
}

sycl/test/esimd/gather_scatter.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,23 @@ kernel5(accessor<int, 1, access::mode::read, access::target::device> &buf)
104104
// function for call to 'scatter'
105105
scatter_rgba(buf, offset, v);
106106
}
107+
108+
// Incompatible mode (read).
109+
SYCL_EXTERNAL void
110+
kernel6(local_accessor<const int, 1> &buf) SYCL_ESIMD_FUNCTION {
111+
simd<int, 32> v(0, 1);
112+
simd<uint32_t, 32> offset(0, 1);
113+
// CHECK: gather_scatter.cpp:115{{.*}}error: no matching function
114+
// function for call to 'scatter'
115+
scatter<int, 32>(buf, offset, v);
116+
}
117+
118+
// Incompatible mode (read).
119+
SYCL_EXTERNAL void
120+
kernel7(local_accessor<const int, 1> &buf) SYCL_ESIMD_FUNCTION {
121+
simd<int, 32 * 4> v(0, 1);
122+
simd<uint32_t, 32> offset(0, sizeof(int) * 4);
123+
// CHECK: gather_scatter.cpp:125{{.*}}error: no matching function
124+
// function for call to 'scatter'
125+
scatter_rgba(buf, offset, v);
126+
}

0 commit comments

Comments
 (0)