Skip to content

Commit a91d722

Browse files
v-klochkovbader
authored andcommitted
[SYCL] Fix several errors in accessor
This patch fixes 2 constructors, fixes get_count() and get_size() methods, adds operator== and operator!=, fixes hash implementation, adds static_asssert to range and id classes to have early check for wrong dimension, fixes few other errors. Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 8d4be77 commit a91d722

File tree

8 files changed

+108
-76
lines changed

8 files changed

+108
-76
lines changed

sycl/include/CL/sycl/accessor2.hpp

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ class accessor_common {
212212
MIDs[0] = Index;
213213
}
214214

215-
template <int CurDims = SubDims, typename = enable_if_t<CurDims != 1>>
215+
template <int CurDims = SubDims, typename = enable_if_t<(CurDims > 1)>>
216216
AccessorSubscript<CurDims - 1> operator[](size_t Index) {
217217
MIDs[Dims - CurDims] = Index;
218218
return AccessorSubscript<CurDims - 1>(MAccessor, MIDs);
@@ -282,19 +282,19 @@ class accessor :
282282

283283
size_t Result = 0;
284284
for (int I = 0; I < Dims; ++I)
285-
Result = Result * getOrigRange()[I] + getOffset()[I] + Id[I];
285+
Result = Result * getMemoryRange()[I] + getOffset()[I] + Id[I];
286286
return Result;
287287
}
288288

289289
#ifdef __SYCL_DEVICE_ONLY__
290290

291291
id<AdjustedDim> &getOffset() { return impl.Offset; }
292-
range<AdjustedDim> &getRange() { return impl.AccessRange; }
293-
range<AdjustedDim> &getOrigRange() { return impl.MemRange; }
292+
range<AdjustedDim> &getAccessRange() { return impl.AccessRange; }
293+
range<AdjustedDim> &getMemoryRange() { return impl.MemRange; }
294294

295295
const id<AdjustedDim> &getOffset() const { return impl.Offset; }
296-
const range<AdjustedDim> &getRange() const { return impl.AccessRange; }
297-
const range<AdjustedDim> &getOrigRange() const { return impl.MemRange; }
296+
const range<AdjustedDim> &getAccessRange() const { return impl.AccessRange; }
297+
const range<AdjustedDim> &getMemoryRange() const { return impl.MemRange; }
298298

299299
detail::AccessorImplDevice<AdjustedDim> impl;
300300

@@ -305,8 +305,8 @@ class accessor :
305305
MData = Ptr;
306306
for (int I = 0; I < AdjustedDim; ++I) {
307307
getOffset()[I] = Offset[I];
308-
getRange()[I] = AccessRange[I];
309-
getOrigRange()[I] = MemRange[I];
308+
getAccessRange()[I] = AccessRange[I];
309+
getMemoryRange()[I] = MemRange[I];
310310
}
311311
// In case of 1D buffer, adjust pointer during initialization rather
312312
// then each time in operator[] or get_pointer functions.
@@ -317,9 +317,9 @@ class accessor :
317317
PtrType getQualifiedPtr() const { return MData; }
318318
#else
319319

320-
using AccessorBaseHost::getRange;
320+
using AccessorBaseHost::getAccessRange;
321321
using AccessorBaseHost::getOffset;
322-
using AccessorBaseHost::getOrigRange;
322+
using AccessorBaseHost::getMemoryRange;
323323

324324
char padding[sizeof(detail::AccessorImplDevice<AdjustedDim>) +
325325
sizeof(PtrType) - sizeof(detail::AccessorBaseHost)];
@@ -332,17 +332,15 @@ class accessor :
332332

333333
public:
334334
using value_type = DataT;
335-
using reference = typename detail::PtrValueType<DataT, AS>::type &;
335+
using reference = DataT &;
336336
using const_reference = const reference;
337337

338338
template <int Dims = Dimensions>
339-
accessor(buffer<DataT, 1> &BufferRef,
340-
enable_if_t<((!IsPlaceH && IsHostBuf) ||
339+
accessor(enable_if_t<((!IsPlaceH && IsHostBuf) ||
341340
(IsPlaceH && (IsGlobalBuf || IsConstantBuf))) &&
342-
Dims == 0>)
341+
Dims == 0, buffer<DataT, 1> > &BufferRef)
343342
#ifdef __SYCL_DEVICE_ONLY__
344-
: impl(id<AdjustedDim>(), BufferRef.get_range(), BufferRef.MemRange)
345-
{
343+
: impl(id<AdjustedDim>(), BufferRef.get_range(), BufferRef.MemRange) {
346344
#else
347345
: AccessorBaseHost(
348346
/*Offset=*/{0, 0, 0},
@@ -351,15 +349,17 @@ class accessor :
351349
detail::getSyclObjImpl(BufferRef).get(), AdjustedDim,
352350
sizeof(DataT)) {
353351
detail::EventImplPtr Event =
354-
detail::Scheduler::getInstance().addHostAccessor(this);
352+
detail::Scheduler::getInstance().addHostAccessor(
353+
AccessorBaseHost::impl.get());
355354
Event->wait(Event);
356355
#endif
357356
}
358357

359358
template <int Dims = Dimensions>
360359
accessor(
361-
buffer<DataT, 1> &BufferRef, handler &CommandGroupHandler,
362-
enable_if_t<(!IsPlaceH && (IsGlobalBuf || IsConstantBuf)) && Dims == 0>)
360+
buffer<DataT, 1> &BufferRef,
361+
enable_if_t<(!IsPlaceH && (IsGlobalBuf || IsConstantBuf)) && Dims == 0,
362+
handler> &CommandGroupHandler)
363363
#ifdef __SYCL_DEVICE_ONLY__
364364
: impl(id<AdjustedDim>(), BufferRef.get_range(), BufferRef.MemRange) {
365365
}
@@ -458,13 +458,13 @@ class accessor :
458458

459459
constexpr bool is_placeholder() const { return IsPlaceH; }
460460

461-
size_t get_size() const { return getRange().size() * sizeof(DataT); }
461+
size_t get_size() const { return getMemoryRange().size() * sizeof(DataT); }
462462

463-
size_t get_count() const { return getRange().size(); }
463+
size_t get_count() const { return getMemoryRange().size(); }
464464

465465
template <int Dims = Dimensions, typename = enable_if_t<(Dims > 0)>>
466466
range<Dimensions> get_range() const {
467-
return detail::convertToArrayOfN<Dimensions, 1>(getRange());
467+
return detail::convertToArrayOfN<Dimensions, 1>(getAccessRange());
468468
}
469469

470470
template <int Dims = Dimensions, typename = enable_if_t<(Dims > 0)>>
@@ -475,7 +475,7 @@ class accessor :
475475
template <int Dims = Dimensions,
476476
typename = enable_if_t<IsAccessAnyWrite && Dims == 0>>
477477
operator RefType() const {
478-
const size_t LinearIndex = getLinearIndex(id<Dimensions>());
478+
const size_t LinearIndex = getLinearIndex(id<AdjustedDim>());
479479
return *(getQualifiedPtr() + LinearIndex);
480480
}
481481

@@ -570,6 +570,9 @@ class accessor :
570570
const size_t LinearIndex = getLinearIndex(id<AdjustedDim>());
571571
return constant_ptr<DataT>(getQualifiedPtr() + LinearIndex);
572572
}
573+
574+
bool operator==(const accessor &Rhs) const { return impl == Rhs.impl; }
575+
bool operator!=(const accessor &Rhs) const { return !(*this == Rhs); }
573576
};
574577

575578
// Local accessor
@@ -604,8 +607,8 @@ class accessor<DataT, Dimensions, AccessMode, access::target::local,
604607
#ifdef __SYCL_DEVICE_ONLY__
605608
detail::LocalAccessorBaseDevice<AdjustedDim> impl;
606609

607-
sycl::range<AdjustedDim> &getSize() { return impl.AccessRange; }
608-
const sycl::range<AdjustedDim> &getSize() const { return impl.AccessRange; }
610+
sycl::range<AdjustedDim> &getSize() { return impl.MemRange; }
611+
const sycl::range<AdjustedDim> &getSize() const { return impl.MemRange; }
609612

610613
void __init(PtrType Ptr, range<AdjustedDim> AccessRange,
611614
range<AdjustedDim> MemRange, id<AdjustedDim> Offset) {
@@ -719,6 +722,9 @@ class accessor<DataT, Dimensions, AccessMode, access::target::local,
719722
local_ptr<DataT> get_pointer() const {
720723
return local_ptr<DataT>(getQualifiedPtr());
721724
}
725+
726+
bool operator==(const accessor &Rhs) const { return impl == Rhs.impl; }
727+
bool operator!=(const accessor &Rhs) const { return !(*this == Rhs); }
722728
};
723729

724730
// Image accessor
@@ -815,9 +821,10 @@ struct hash<cl::sycl::accessor<DataT, Dimensions, AccessMode, AccessTarget,
815821
// Hash is not supported on DEVICE. Just return 0 here.
816822
return 0;
817823
#else
818-
std::shared_ptr<cl::sycl::detail::AccessorBaseHost> AccBaseImplPtr =
819-
cl::sycl::detail::getSyclObjImpl(A);
820-
return hash<decltype(AccBaseImplPtr)>()(AccBaseImplPtr);
824+
// getSyclObjImpl() here returns a pointer to either AccessorImplHost
825+
// or LocalAccessorImplHost depending on the AccessTarget.
826+
auto AccImplPtr = cl::sycl::detail::getSyclObjImpl(A);
827+
return hash<decltype(AccImplPtr)>()(AccImplPtr);
821828
#endif
822829
}
823830
};

sycl/include/CL/sycl/detail/accessor_impl.hpp

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,19 @@ namespace detail {
2727

2828
template <int Dims> class AccessorImplDevice {
2929
public:
30-
AccessorImplDevice(id<Dims> Offset, range<Dims> Range, range<Dims> OrigRange)
31-
: Offset(Offset), AccessRange(Range), MemRange(OrigRange) {}
30+
AccessorImplDevice(id<Dims> Offset, range<Dims> AccessRange,
31+
range<Dims> MemoryRange)
32+
: Offset(Offset), AccessRange(AccessRange), MemRange(MemoryRange) {}
3233

3334
id<Dims> Offset;
3435
range<Dims> AccessRange;
3536
range<Dims> MemRange;
37+
38+
bool operator==(const AccessorImplDevice &Rhs) const {
39+
return (Offset == Rhs.Offset &&
40+
AccessRange == Rhs.AccessRange &&
41+
MemRange == Rhs.MemRange);
42+
}
3643
};
3744

3845
template <int Dims> class LocalAccessorBaseDevice {
@@ -43,14 +50,18 @@ template <int Dims> class LocalAccessorBaseDevice {
4350
range<Dims> AccessRange;
4451
range<Dims> MemRange;
4552
id<Dims> Offset;
53+
54+
bool operator==(const LocalAccessorBaseDevice &Rhs) const {
55+
return (AccessRange == Rhs.AccessRange);
56+
}
4657
};
4758

4859
class AccessorImplHost {
4960
public:
50-
AccessorImplHost(id<3> Offset, range<3> Range, range<3> OrigRange,
61+
AccessorImplHost(id<3> Offset, range<3> AccessRange, range<3> MemoryRange,
5162
access::mode AccessMode, detail::SYCLMemObjT *SYCLMemObject,
5263
int Dims, int ElemSize)
53-
: MOffset(Offset), MRange(Range), MOrigRange(OrigRange),
64+
: MOffset(Offset), MAccessRange(AccessRange), MMemoryRange(MemoryRange),
5465
MAccessMode(AccessMode), MSYCLMemObj(SYCLMemObject), MDims(Dims),
5566
MElemSize(ElemSize) {}
5667

@@ -59,16 +70,16 @@ class AccessorImplHost {
5970
BlockingEvent->setComplete();
6071
}
6172
AccessorImplHost(const AccessorImplHost &Other)
62-
: MOffset(Other.MOffset), MRange(Other.MRange),
63-
MOrigRange(Other.MOrigRange), MAccessMode(Other.MAccessMode),
73+
: MOffset(Other.MOffset), MAccessRange(Other.MAccessRange),
74+
MMemoryRange(Other.MMemoryRange), MAccessMode(Other.MAccessMode),
6475
MSYCLMemObj(Other.MSYCLMemObj), MDims(Other.MDims),
6576
MElemSize(Other.MElemSize) {}
6677

6778
id<3> MOffset;
6879
// The size of accessing region.
69-
range<3> MRange;
80+
range<3> MAccessRange;
7081
// The size of memory object this requirement is created for.
71-
range<3> MOrigRange;
82+
range<3> MMemoryRange;
7283
access::mode MAccessMode;
7384

7485
detail::SYCLMemObjT *MSYCLMemObj;
@@ -85,22 +96,23 @@ using AccessorImplPtr = std::shared_ptr<AccessorImplHost>;
8596

8697
class AccessorBaseHost {
8798
public:
88-
AccessorBaseHost(id<3> Offset, range<3> Range, range<3> OrigRange,
99+
AccessorBaseHost(id<3> Offset, range<3> AccessRange, range<3> MemoryRange,
89100
access::mode AccessMode, detail::SYCLMemObjT *SYCLMemObject,
90101
int Dims, int ElemSize) {
91-
impl = std::make_shared<AccessorImplHost>(
92-
Offset, Range, OrigRange, AccessMode, SYCLMemObject, Dims, ElemSize);
102+
impl = std::make_shared<AccessorImplHost>(Offset, AccessRange, MemoryRange,
103+
AccessMode, SYCLMemObject,
104+
Dims, ElemSize);
93105
}
94106

95107
protected:
96108
id<3> &getOffset() { return impl->MOffset; }
97-
range<3> &getRange() { return impl->MRange; }
98-
range<3> &getOrigRange() { return impl->MOrigRange; }
109+
range<3> &getAccessRange() { return impl->MAccessRange; }
110+
range<3> &getMemoryRange() { return impl->MMemoryRange; }
99111
void *getPtr() { return impl->MData; }
100112

101113
const id<3> &getOffset() const { return impl->MOffset; }
102-
const range<3> &getRange() const { return impl->MRange; }
103-
const range<3> &getOrigRange() const { return impl->MOrigRange; }
114+
const range<3> &getAccessRange() const { return impl->MAccessRange; }
115+
const range<3> &getMemoryRange() const { return impl->MMemoryRange; }
104116
void *getPtr() const { return const_cast<void *>(impl->MData); }
105117

106118
template <class Obj>
@@ -135,7 +147,10 @@ class LocalAccessorBaseHost {
135147

136148
int getNumOfDims() { return impl->MDims; }
137149
int getElementSize() { return impl->MElemSize; }
150+
138151
protected:
152+
template <class Obj>
153+
friend decltype(Obj::impl) detail::getSyclObjImpl(const Obj &SyclObject);
139154

140155
std::shared_ptr<LocalAccessorImplHost> impl;
141156
};

sycl/include/CL/sycl/detail/array.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ template <int dimensions> class range;
1919
namespace detail {
2020

2121
template <int dimensions = 1> class array {
22+
static_assert(dimensions >= 1, "Array cannot be 0-dimensional.");
2223
public:
2324
array() : common_array{0} {}
2425

sycl/include/CL/sycl/handler2.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,11 @@ class handler {
234234

235235
MArgs.emplace_back(
236236
detail::ArgDesc(detail::kernel_param_kind_t::kind_std_layout,
237-
&(AccImpl->MRange[0]),
237+
&(AccImpl->MAccessRange[0]),
238238
sizeof(size_t) * AccImpl->MDims, NextArgId + 1));
239239
MArgs.emplace_back(
240240
detail::ArgDesc(detail::kernel_param_kind_t::kind_std_layout,
241-
&AccImpl->MOrigRange[0],
241+
&AccImpl->MMemoryRange[0],
242242
sizeof(size_t) * AccImpl->MDims, NextArgId + 2));
243243
MArgs.emplace_back(
244244
detail::ArgDesc(detail::kernel_param_kind_t::kind_std_layout,

sycl/include/CL/sycl/id.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ template <int dimensions> class range;
1818
template <int dimensions = 1> struct id : public detail::array<dimensions> {
1919
private:
2020
using base = detail::array<dimensions>;
21+
static_assert(dimensions >= 1 && dimensions <= 3,
22+
"id can only be 1, 2, or 3 dimentional.");
2123
public:
2224
id() = default;
2325

sycl/include/CL/sycl/range.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ namespace sycl {
1616
template <int dimensions> struct id;
1717
template <int dimensions = 1>
1818
class range : public detail::array<dimensions> {
19+
static_assert(dimensions >= 1 && dimensions <= 3,
20+
"range can only be 1, 2, or 3 dimentional.");
1921
using base = detail::array<dimensions>;
2022
public:
2123
/* The following constructor is only available in the range class

0 commit comments

Comments
 (0)