Skip to content

Commit b4d2b2c

Browse files
[SYCL] Fix empty zero-dimensional accessor access range (#10156)
The current implementation of zero-dimensional accessors always assume that its size is 1, but this is not the case if the underlying buffer is empty. This commit fixes this behavior by correcting the used access range for the accessor, making it correctly report the right sizes and iterators. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent d3b5b59 commit b4d2b2c

File tree

2 files changed

+102
-12
lines changed

2 files changed

+102
-12
lines changed

sycl/include/sycl/accessor.hpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ template <> struct IsCxPropertyList<ext::oneapi::accessor_property_list<>> {
276276
constexpr static bool value = false;
277277
};
278278

279+
// Zero-dimensional accessors references at-most a single element, so the range
280+
// is either 0 if the associated buffer is empty or 1 otherwise.
281+
template <typename BufferT>
282+
sycl::range<1> GetZeroDimAccessRange(BufferT Buffer) {
283+
return std::min(Buffer.size(), size_t{1});
284+
}
285+
279286
__SYCL_EXPORT device getDeviceFromHandler(handler &CommandGroupHandlerRef);
280287

281288
template <typename DataT, int Dimensions, access::mode AccessMode,
@@ -1512,11 +1519,14 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
15121519
const property_list &PropertyList = {},
15131520
const detail::code_location CodeLoc = detail::code_location::current())
15141521
#ifdef __SYCL_DEVICE_ONLY__
1515-
: impl(id<AdjustedDim>(), range<1>{1}, BufferRef.get_range()) {
1522+
: impl(id<AdjustedDim>(), detail::GetZeroDimAccessRange(BufferRef),
1523+
BufferRef.get_range()) {
15161524
(void)PropertyList;
15171525
#else
15181526
: AccessorBaseHost(
1519-
/*Offset=*/{0, 0, 0}, detail::convertToArrayOfN<3, 1>(range<1>{1}),
1527+
/*Offset=*/{0, 0, 0},
1528+
detail::convertToArrayOfN<3, 1>(
1529+
detail::GetZeroDimAccessRange(BufferRef)),
15201530
detail::convertToArrayOfN<3, 1>(BufferRef.get_range()),
15211531
getAdjustedMode(PropertyList),
15221532
detail::getSyclObjImpl(BufferRef).get(), AdjustedDim, sizeof(DataT),
@@ -1548,11 +1558,14 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
15481558
{},
15491559
const detail::code_location CodeLoc = detail::code_location::current())
15501560
#ifdef __SYCL_DEVICE_ONLY__
1551-
: impl(id<AdjustedDim>(), range<1>{1}, BufferRef.get_range()) {
1561+
: impl(id<AdjustedDim>(), detail::GetZeroDimAccessRange(BufferRef),
1562+
BufferRef.get_range()) {
15521563
(void)PropertyList;
15531564
#else
15541565
: AccessorBaseHost(
1555-
/*Offset=*/{0, 0, 0}, detail::convertToArrayOfN<3, 1>(range<1>{1}),
1566+
/*Offset=*/{0, 0, 0},
1567+
detail::convertToArrayOfN<3, 1>(
1568+
detail::GetZeroDimAccessRange(BufferRef)),
15561569
detail::convertToArrayOfN<3, 1>(BufferRef.get_range()),
15571570
getAdjustedMode(PropertyList),
15581571
detail::getSyclObjImpl(BufferRef).get(), AdjustedDim, sizeof(DataT),
@@ -1579,13 +1592,16 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
15791592
const property_list &PropertyList = {},
15801593
const detail::code_location CodeLoc = detail::code_location::current())
15811594
#ifdef __SYCL_DEVICE_ONLY__
1582-
: impl(id<AdjustedDim>(), range<1>{1}, BufferRef.get_range()) {
1595+
: impl(id<AdjustedDim>(), detail::GetZeroDimAccessRange(BufferRef),
1596+
BufferRef.get_range()) {
15831597
(void)CommandGroupHandler;
15841598
(void)PropertyList;
15851599
}
15861600
#else
15871601
: AccessorBaseHost(
1588-
/*Offset=*/{0, 0, 0}, detail::convertToArrayOfN<3, 1>(range<1>{1}),
1602+
/*Offset=*/{0, 0, 0},
1603+
detail::convertToArrayOfN<3, 1>(
1604+
detail::GetZeroDimAccessRange(BufferRef)),
15891605
detail::convertToArrayOfN<3, 1>(BufferRef.get_range()),
15901606
getAdjustedMode(PropertyList),
15911607
detail::getSyclObjImpl(BufferRef).get(), Dimensions, sizeof(DataT),
@@ -1612,13 +1628,16 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
16121628
{},
16131629
const detail::code_location CodeLoc = detail::code_location::current())
16141630
#ifdef __SYCL_DEVICE_ONLY__
1615-
: impl(id<AdjustedDim>(), range<1>{1}, BufferRef.get_range()) {
1631+
: impl(id<AdjustedDim>(), detail::GetZeroDimAccessRange(BufferRef),
1632+
BufferRef.get_range()) {
16161633
(void)CommandGroupHandler;
16171634
(void)PropertyList;
16181635
}
16191636
#else
16201637
: AccessorBaseHost(
1621-
/*Offset=*/{0, 0, 0}, detail::convertToArrayOfN<3, 1>(range<1>{1}),
1638+
/*Offset=*/{0, 0, 0},
1639+
detail::convertToArrayOfN<3, 1>(
1640+
detail::GetZeroDimAccessRange(BufferRef)),
16221641
detail::convertToArrayOfN<3, 1>(BufferRef.get_range()),
16231642
getAdjustedMode(PropertyList),
16241643
detail::getSyclObjImpl(BufferRef).get(), Dimensions, sizeof(DataT),
@@ -2416,10 +2435,7 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
24162435
private:
24172436
template <int Dims, typename = std::enable_if_t<(Dims > 0)>>
24182437
range<Dims> getRange() const {
2419-
if constexpr (Dimensions == 0)
2420-
return range<1>{1};
2421-
else
2422-
return detail::convertToArrayOfN<Dims, 1>(getAccessRange());
2438+
return detail::convertToArrayOfN<AdjustedDim, 1>(getAccessRange());
24232439
}
24242440

24252441
template <int Dims = Dimensions, typename = std::enable_if_t<(Dims > 0)>>
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
// Disabled for HIP (https://github.com/intel/llvm/issues/10358) and CUDA
5+
// (https://github.com/intel/llvm/issues/10360)
6+
// UNSUPPORTED: cuda || hip
7+
8+
// Tests the size and iterator members of an empty zero-dimensional accessor.
9+
10+
#include <sycl/sycl.hpp>
11+
12+
using namespace sycl;
13+
14+
int check_host(bool CheckResult, std::string Msg) {
15+
if (!CheckResult)
16+
std::cout << "Case failed: " << Msg << std::endl;
17+
return !CheckResult;
18+
}
19+
20+
int main() {
21+
int Failures = 0;
22+
queue Q;
23+
24+
buffer<int, 1> EmptyBuf{0};
25+
assert(EmptyBuf.size() == 0);
26+
27+
{
28+
host_accessor<int, 0> EmptyHostAcc{EmptyBuf};
29+
Failures += check_host(EmptyHostAcc.empty(), "empty() on host_accesor");
30+
Failures += check_host(EmptyHostAcc.size() == 0, "size() on host_accesor");
31+
Failures += check_host(EmptyHostAcc.byte_size() == 0,
32+
"byte_size() on host_accesor");
33+
Failures +=
34+
check_host(EmptyHostAcc.max_size() == 0, "max_size() on host_accesor");
35+
Failures += check_host(EmptyHostAcc.begin() == EmptyHostAcc.end(),
36+
"begin()/end() on host_accesor");
37+
Failures += check_host(EmptyHostAcc.cbegin() == EmptyHostAcc.cend(),
38+
"cbegin()/cend() on host_accesor");
39+
Failures += check_host(EmptyHostAcc.rbegin() == EmptyHostAcc.rend(),
40+
"rbegin()/rend() on host_accesor");
41+
Failures += check_host(EmptyHostAcc.crbegin() == EmptyHostAcc.crend(),
42+
"crbegin()/crend() on host_accesor");
43+
}
44+
45+
bool DeviceResults[8] = {false};
46+
{
47+
buffer<bool, 1> DeviceResultsBuf{DeviceResults, range<1>{8}};
48+
Q.submit([&](handler &CGH) {
49+
accessor<int, 0> EmptyDevAcc{EmptyBuf, CGH};
50+
accessor DeviceResultsAcc{DeviceResultsBuf, CGH};
51+
CGH.single_task([=]() {
52+
DeviceResultsAcc[0] = EmptyDevAcc.empty();
53+
DeviceResultsAcc[1] = EmptyDevAcc.size() == 0;
54+
DeviceResultsAcc[2] = EmptyDevAcc.byte_size() == 0;
55+
DeviceResultsAcc[3] = EmptyDevAcc.max_size() == 0;
56+
DeviceResultsAcc[4] = EmptyDevAcc.begin() == EmptyDevAcc.end();
57+
DeviceResultsAcc[5] = EmptyDevAcc.cbegin() == EmptyDevAcc.cend();
58+
DeviceResultsAcc[6] = EmptyDevAcc.rbegin() == EmptyDevAcc.rend();
59+
DeviceResultsAcc[7] = EmptyDevAcc.crbegin() == EmptyDevAcc.crend();
60+
});
61+
});
62+
}
63+
64+
Failures += check_host(DeviceResults[0], "empty() on accessor");
65+
Failures += check_host(DeviceResults[1], "size() on accessor");
66+
Failures += check_host(DeviceResults[2], "byte_size() on accessor");
67+
Failures += check_host(DeviceResults[3], "max_size() on accessor");
68+
Failures += check_host(DeviceResults[4], "begin()/end() on accessor");
69+
Failures += check_host(DeviceResults[5], "cbegin()/cend() on accessor");
70+
Failures += check_host(DeviceResults[6], "rbegin()/rend() on accessor");
71+
Failures += check_host(DeviceResults[7], "crbegin()/crend() on accessor");
72+
73+
return Failures;
74+
}

0 commit comments

Comments
 (0)