Skip to content

Commit 8c244de

Browse files
[SYCL] Fix accessor subscript ambiguity and reference type (#8033)
Accessor subscript operations may cause ambiguity when the data type is constant. This commit fixes this ambiguity and makes the reference member alias of accessors adhere to the SYCL 2020 definition of them. --------- Signed-off-by: Larsen, Steffen <[email protected]> Co-authored-by: Alexey Sachkov <[email protected]>
1 parent e616e81 commit 8c244de

File tree

2 files changed

+107
-29
lines changed

2 files changed

+107
-29
lines changed

sycl/include/sycl/accessor.hpp

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,9 @@ class accessor_common {
344344
}
345345

346346
template <int CurDims = SubDims,
347-
typename = detail::enable_if_t<CurDims == 1 && IsAccessAnyWrite>>
348-
RefType operator[](size_t Index) const {
347+
typename = detail::enable_if_t<
348+
CurDims == 1 && (IsAccessReadOnly || IsAccessAnyWrite)>>
349+
typename AccType::reference operator[](size_t Index) const {
349350
MIDs[Dims - CurDims] = Index;
350351
return MAccessor[MIDs];
351352
}
@@ -357,13 +358,6 @@ class accessor_common {
357358
MIDs[Dims - CurDims] = Index;
358359
return MAccessor[MIDs];
359360
}
360-
361-
template <int CurDims = SubDims,
362-
typename = detail::enable_if_t<CurDims == 1 && IsAccessReadOnly>>
363-
ConstRefType operator[](size_t Index) const {
364-
MIDs[Dims - SubDims] = Index;
365-
return MAccessor[MIDs];
366-
}
367361
};
368362
};
369363

@@ -1213,7 +1207,7 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
12131207
// otherwise
12141208
using value_type = typename std::conditional<AccessMode == access_mode::read,
12151209
const DataT, DataT>::type;
1216-
using reference = DataT &;
1210+
using reference = value_type &;
12171211
using const_reference = const DataT &;
12181212

12191213
using iterator = typename detail::accessor_iterator<value_type, Dimensions>;
@@ -1973,30 +1967,17 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
19731967
}
19741968

19751969
template <int Dims = Dimensions, typename RefT = RefType,
1976-
typename = detail::enable_if_t<Dims == 0 && IsAccessAnyWrite &&
1977-
!std::is_const<RefT>::value>>
1978-
operator RefType() const {
1970+
typename = detail::enable_if_t<Dims == 0 && (IsAccessAnyWrite ||
1971+
IsAccessReadOnly)>>
1972+
operator reference() const {
19791973
const size_t LinearIndex = getLinearIndex(id<AdjustedDim>());
19801974
return *(getQualifiedPtr() + LinearIndex);
19811975
}
19821976

19831977
template <int Dims = Dimensions,
1984-
typename = detail::enable_if_t<Dims == 0 && IsAccessReadOnly>>
1985-
operator ConstRefType() const {
1986-
const size_t LinearIndex = getLinearIndex(id<AdjustedDim>());
1987-
return *(getQualifiedPtr() + LinearIndex);
1988-
}
1989-
1990-
template <int Dims = Dimensions,
1991-
typename = detail::enable_if_t<(Dims > 0) && IsAccessAnyWrite>>
1992-
RefType operator[](id<Dimensions> Index) const {
1993-
const size_t LinearIndex = getLinearIndex(Index);
1994-
return getQualifiedPtr()[LinearIndex];
1995-
}
1996-
1997-
template <int Dims = Dimensions>
1998-
typename detail::enable_if_t<(Dims > 0) && IsAccessReadOnly, ConstRefType>
1999-
operator[](id<Dimensions> Index) const {
1978+
typename = detail::enable_if_t<(Dims > 0) && (IsAccessAnyWrite ||
1979+
IsAccessReadOnly)>>
1980+
reference operator[](id<Dimensions> Index) const {
20001981
const size_t LinearIndex = getLinearIndex(Index);
20011982
return getQualifiedPtr()[LinearIndex];
20021983
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// RUN: %clangxx -fsycl -fsyntax-only %s
2+
3+
// Test checks that the subscript operator and reference alias on accessors
4+
// evaluate to the right types.
5+
#include <sycl/sycl.hpp>
6+
7+
using namespace sycl;
8+
9+
// Trait for getting the return type of a full subscript operation.
10+
template <int Dims, typename AccT> struct FullSubscriptType;
11+
template <typename AccT> struct FullSubscriptType<1, AccT> {
12+
using type = decltype(std::declval<AccT>()[0]);
13+
};
14+
template <typename AccT> struct FullSubscriptType<2, AccT> {
15+
using type = decltype(std::declval<AccT>()[0][0]);
16+
};
17+
template <typename AccT> struct FullSubscriptType<3, AccT> {
18+
using type = decltype(std::declval<AccT>()[0][0][0]);
19+
};
20+
template <int Dims, typename AccT>
21+
using FullSubscriptTypeT = typename FullSubscriptType<Dims, AccT>::type;
22+
23+
// Expected reference type of an accessor given an access mode.
24+
template <access::mode AccessMode, typename DataT>
25+
using ExpectedRefTypeT = std::conditional_t<AccessMode == access::mode::read,
26+
std::add_const_t<DataT> &, DataT &>;
27+
28+
// Trait for getting the expected return type of a full subscript operation.
29+
template <access::mode AccessMode, access::target AccessTarget, typename DataT>
30+
struct ExpectedSubscriptType {
31+
using type = ExpectedRefTypeT<AccessMode, DataT>;
32+
};
33+
template <typename DataT, access::target AccessTarget>
34+
struct ExpectedSubscriptType<access::mode::atomic, AccessTarget, DataT> {
35+
using type = atomic<DataT, access::address_space::global_space>;
36+
};
37+
template <typename DataT>
38+
struct ExpectedSubscriptType<access::mode::atomic, access::target::local,
39+
DataT> {
40+
using type = atomic<DataT, access::address_space::local_space>;
41+
};
42+
template <access::mode AccessMode, access::target AccessTarget, typename DataT>
43+
using ExpectedSubscriptTypeT =
44+
typename ExpectedSubscriptType<AccessMode, AccessTarget, DataT>::type;
45+
46+
template <typename DataT, int Dims, access::mode AccessMode,
47+
access::target AccessTarget, typename AccT>
48+
void CheckAccRefAndSubscript() {
49+
static_assert(std::is_same_v<typename AccT::reference,
50+
ExpectedRefTypeT<AccessMode, DataT>>);
51+
static_assert(
52+
std::is_same_v<FullSubscriptTypeT<Dims, AccT>,
53+
ExpectedSubscriptTypeT<AccessMode, AccessTarget, DataT>>);
54+
}
55+
56+
template <typename DataT, int Dims, access::mode AccessMode> void CheckAcc() {
57+
CheckAccRefAndSubscript<DataT, Dims, AccessMode, access::target::host_buffer,
58+
host_accessor<DataT, Dims, AccessMode>>();
59+
CheckAccRefAndSubscript<
60+
DataT, Dims, AccessMode, access::target::device,
61+
accessor<DataT, Dims, AccessMode, access::target::device>>();
62+
CheckAccRefAndSubscript<
63+
DataT, Dims, AccessMode, access::target::host_buffer,
64+
accessor<DataT, Dims, AccessMode, access::target::host_buffer>>();
65+
if constexpr (AccessMode == access::mode::read_write) {
66+
CheckAccRefAndSubscript<DataT, Dims, AccessMode, access::target::local,
67+
local_accessor<DataT, Dims>>();
68+
}
69+
if constexpr (AccessMode == access::mode::read_write ||
70+
AccessMode == access::mode::atomic) {
71+
CheckAccRefAndSubscript<
72+
DataT, Dims, AccessMode, access::target::local,
73+
accessor<DataT, Dims, AccessMode, access::target::local>>();
74+
}
75+
}
76+
77+
template <typename DataT, access::mode AccessMode> void CheckAccAllDims() {
78+
CheckAcc<DataT, 1, AccessMode>();
79+
CheckAcc<DataT, 2, AccessMode>();
80+
CheckAcc<DataT, 3, AccessMode>();
81+
}
82+
83+
template <typename DataT> void CheckAccAllAccessModesAndDims() {
84+
CheckAccAllDims<DataT, access::mode::write>();
85+
CheckAccAllDims<DataT, access::mode::read_write>();
86+
CheckAccAllDims<DataT, access::mode::discard_write>();
87+
CheckAccAllDims<DataT, access::mode::discard_read_write>();
88+
CheckAccAllDims<DataT, access::mode::read>();
89+
if constexpr (!std::is_const_v<DataT>)
90+
CheckAccAllDims<DataT, access::mode::atomic>();
91+
}
92+
93+
int main() {
94+
CheckAccAllAccessModesAndDims<int>();
95+
CheckAccAllAccessModesAndDims<const int>();
96+
return 0;
97+
}

0 commit comments

Comments
 (0)