Skip to content

Commit 225b3ec

Browse files
[SYCL] Fix weak_object for host_accessor and stream (#8903)
This commit makes a specialization of weak_object for stream, similar to buffer, to help recreate a stream from its constituents. Additionally it fixes a problem for host_accessor where the bases would conflict with eachother as they defined the same functions. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 9ebe3cf commit 225b3ec

File tree

6 files changed

+161
-9
lines changed

6 files changed

+161
-9
lines changed

sycl/include/sycl/accessor.hpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,9 +3060,7 @@ template <typename DataT, int Dimensions = 1,
30603060
access_mode AccessMode = access_mode::read_write>
30613061
class __SYCL_EBO host_accessor
30623062
: public accessor<DataT, Dimensions, AccessMode, target::host_buffer,
3063-
access::placeholder::false_t>,
3064-
public detail::OwnerLessBase<
3065-
host_accessor<DataT, Dimensions, AccessMode>> {
3063+
access::placeholder::false_t> {
30663064
protected:
30673065
using AccessorT = accessor<DataT, Dimensions, AccessMode, target::host_buffer,
30683066
access::placeholder::false_t>;
@@ -3083,6 +3081,18 @@ class __SYCL_EBO host_accessor
30833081
AccessorT::__init(Ptr, AccessRange, MemRange, Offset);
30843082
}
30853083

3084+
#ifndef __SYCL_DEVICE_ONLY__
3085+
host_accessor(const detail::AccessorImplPtr &Impl)
3086+
: accessor<DataT, Dimensions, AccessMode, target::host_buffer,
3087+
access::placeholder::false_t>{Impl} {}
3088+
3089+
template <class Obj>
3090+
friend decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject);
3091+
3092+
template <class T>
3093+
friend T detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);
3094+
#endif // __SYCL_DEVICE_ONLY__
3095+
30863096
public:
30873097
host_accessor() : AccessorT() {}
30883098

@@ -3240,6 +3250,28 @@ class __SYCL_EBO host_accessor
32403250
*AccessorT::getQualifiedPtr() = std::move(Other);
32413251
return *this;
32423252
}
3253+
3254+
// host_accessor needs to explicitly define the owner_before member functions
3255+
// as inheriting from OwnerLessBase causes base class conflicts.
3256+
// TODO: Once host_accessor is detached from accessor, inherit from
3257+
// OwnerLessBase instead.
3258+
#ifndef __SYCL_DEVICE_ONLY__
3259+
bool ext_oneapi_owner_before(
3260+
const ext::oneapi::detail::weak_object_base<host_accessor> &Other)
3261+
const noexcept {
3262+
return this->impl.owner_before(
3263+
ext::oneapi::detail::getSyclWeakObjImpl(Other));
3264+
}
3265+
3266+
bool ext_oneapi_owner_before(const host_accessor &Other) const noexcept {
3267+
return this->impl.owner_before(Other.impl);
3268+
}
3269+
#else
3270+
bool ext_oneapi_owner_before(
3271+
const ext::oneapi::detail::weak_object_base<host_accessor> &Other)
3272+
const noexcept;
3273+
bool ext_oneapi_owner_before(const host_accessor &Other) const noexcept;
3274+
#endif
32433275
};
32443276

32453277
template <typename DataT, int Dimensions, typename AllocatorT>

sycl/include/sycl/buffer.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,9 @@ class buffer : public detail::buffer_plain,
722722
detail::make_buffer_helper(pi_native_handle, const context &, event, bool);
723723
template <typename SYCLObjT> friend class ext::oneapi::weak_object;
724724

725+
// NOTE: These members are required for reconstructing the buffer, but are not
726+
// part of the implementation class. If more members are added, they should
727+
// also be added to the weak_object specialization for buffers.
725728
range<dimensions> Range;
726729
// Offset field specifies the origin of the sub buffer inside the parent
727730
// buffer

sycl/include/sycl/ext/oneapi/owner_less.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ struct owner_less<kernel_id> : public detail::owner_less_base<kernel_id> {};
6161
template <>
6262
struct owner_less<platform> : public detail::owner_less_base<platform> {};
6363
template <> struct owner_less<queue> : public detail::owner_less_base<queue> {};
64+
template <>
65+
struct owner_less<stream> : public detail::owner_less_base<stream> {};
6466

6567
template <bundle_state State>
6668
struct owner_less<device_image<State>>
@@ -86,10 +88,6 @@ struct owner_less<host_accessor<DataT, Dimensions, AccessMode>>
8688
: public detail::owner_less_base<
8789
host_accessor<DataT, Dimensions, AccessMode>> {};
8890

89-
template <typename DataT, int Dimensions>
90-
struct owner_less<host_accessor<DataT, Dimensions>>
91-
: public detail::owner_less_base<host_accessor<DataT, Dimensions>> {};
92-
9391
template <typename DataT, int Dimensions>
9492
struct owner_less<local_accessor<DataT, Dimensions>>
9593
: public detail::owner_less_base<local_accessor<DataT, Dimensions>> {};

sycl/include/sycl/ext/oneapi/weak_object.hpp

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <sycl/buffer.hpp>
1212
#include <sycl/detail/defines_elementary.hpp>
1313
#include <sycl/ext/oneapi/weak_object_base.hpp>
14+
#include <sycl/stream.hpp>
1415

1516
#include <optional>
1617

@@ -50,7 +51,8 @@ class weak_object : public detail::weak_object_base<SYCLObjT> {
5051

5152
weak_object &operator=(const SYCLObjT &SYCLObj) noexcept {
5253
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
53-
this->MObjWeakPtr = GetWeakImpl(SYCLObj);
54+
this->MObjWeakPtr =
55+
detail::weak_object_base<SYCLObjT>::GetWeakImpl(SYCLObj);
5456
return *this;
5557
}
5658
weak_object &operator=(const weak_object &Other) noexcept = default;
@@ -103,7 +105,8 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
103105

104106
weak_object &operator=(const buffer_type &SYCLObj) noexcept {
105107
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
106-
this->MObjWeakPtr = GetWeakImpl(SYCLObj);
108+
this->MObjWeakPtr = detail::weak_object_base<
109+
buffer<T, Dimensions, AllocatorT>>::GetWeakImpl(SYCLObj);
107110
this->MRange = SYCLObj.Range;
108111
this->MOffsetInBytes = SYCLObj.OffsetInBytes;
109112
this->MIsSubBuffer = SYCLObj.IsSubBuffer;
@@ -112,6 +115,13 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
112115
weak_object &operator=(const weak_object &Other) noexcept = default;
113116
weak_object &operator=(weak_object &&Other) noexcept = default;
114117

118+
void swap(weak_object &Other) noexcept {
119+
this->MObjWeakPtr.swap(Other.MObjWeakPtr);
120+
std::swap(MRange, Other.MRange);
121+
std::swap(MOffsetInBytes, Other.MOffsetInBytes);
122+
std::swap(MIsSubBuffer, Other.MIsSubBuffer);
123+
}
124+
115125
#ifndef __SYCL_DEVICE_ONLY__
116126
std::optional<buffer_type> try_lock() const noexcept {
117127
auto MObjImplPtr = this->MObjWeakPtr.lock();
@@ -141,6 +151,78 @@ class weak_object<buffer<T, Dimensions, AllocatorT>>
141151
bool MIsSubBuffer;
142152
};
143153

154+
// Specialization of weak_object for stream as it needs additional members
155+
// to reconstruct the original stream.
156+
template <>
157+
class weak_object<stream> : public detail::weak_object_base<stream> {
158+
public:
159+
using object_type = typename detail::weak_object_base<stream>::object_type;
160+
161+
constexpr weak_object() noexcept : detail::weak_object_base<stream>() {}
162+
weak_object(const stream &SYCLObj) noexcept
163+
: detail::weak_object_base<stream>(SYCLObj),
164+
MWeakGlobalBuf{SYCLObj.GlobalBuf},
165+
MWeakGlobalOffset{SYCLObj.GlobalOffset},
166+
MWeakGlobalFlushBuf{SYCLObj.GlobalFlushBuf} {}
167+
weak_object(const weak_object &Other) noexcept = default;
168+
weak_object(weak_object &&Other) noexcept = default;
169+
170+
weak_object &operator=(const stream &SYCLObj) noexcept {
171+
// Create weak_ptr from the shared_ptr to SYCLObj's implementation object.
172+
this->MObjWeakPtr = detail::weak_object_base<stream>::GetWeakImpl(SYCLObj);
173+
MWeakGlobalBuf = SYCLObj.GlobalBuf;
174+
MWeakGlobalOffset = SYCLObj.GlobalOffset;
175+
MWeakGlobalFlushBuf = SYCLObj.GlobalFlushBuf;
176+
return *this;
177+
}
178+
weak_object &operator=(const weak_object &Other) noexcept = default;
179+
weak_object &operator=(weak_object &&Other) noexcept = default;
180+
181+
void swap(weak_object &Other) noexcept {
182+
this->MObjWeakPtr.swap(Other.MObjWeakPtr);
183+
MWeakGlobalBuf.swap(Other.MWeakGlobalBuf);
184+
MWeakGlobalOffset.swap(Other.MWeakGlobalOffset);
185+
MWeakGlobalFlushBuf.swap(Other.MWeakGlobalFlushBuf);
186+
}
187+
188+
void reset() noexcept {
189+
this->MObjWeakPtr.reset();
190+
MWeakGlobalBuf.reset();
191+
MWeakGlobalOffset.reset();
192+
MWeakGlobalFlushBuf.reset();
193+
}
194+
195+
#ifndef __SYCL_DEVICE_ONLY__
196+
std::optional<stream> try_lock() const noexcept {
197+
auto ObjImplPtr = this->MObjWeakPtr.lock();
198+
auto GlobalBuf = MWeakGlobalBuf.try_lock();
199+
auto GlobalOffset = MWeakGlobalOffset.try_lock();
200+
auto GlobalFlushBuf = MWeakGlobalFlushBuf.try_lock();
201+
if (!ObjImplPtr || !GlobalBuf || !GlobalOffset || !GlobalFlushBuf)
202+
return std::nullopt;
203+
return stream{ObjImplPtr, *GlobalBuf, *GlobalOffset, *GlobalFlushBuf};
204+
}
205+
stream lock() const {
206+
std::optional<stream> OptionalObj = try_lock();
207+
if (!OptionalObj)
208+
throw sycl::exception(sycl::make_error_code(sycl::errc::invalid),
209+
"Referenced object has expired.");
210+
return *OptionalObj;
211+
}
212+
#else
213+
// On device calls to these functions are disallowed, so declare them but
214+
// don't define them to avoid compilation failures.
215+
std::optional<stream> try_lock() const noexcept;
216+
stream lock() const;
217+
#endif // __SYCL_DEVICE_ONLY__
218+
219+
private:
220+
// Additional members required for recreating stream.
221+
weak_object<detail::GlobalBufAccessorT> MWeakGlobalBuf;
222+
weak_object<detail::GlobalOffsetAccessorT> MWeakGlobalOffset;
223+
weak_object<detail::GlobalBufAccessorT> MWeakGlobalFlushBuf;
224+
};
225+
144226
} // namespace ext::oneapi
145227
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
146228
} // namespace sycl

sycl/include/sycl/stream.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,17 @@ inline __width_manipulator__ setw(int Width) {
743743
/// \ingroup sycl_api
744744
class __SYCL_EXPORT __SYCL_SPECIAL_CLASS __SYCL_TYPE(stream) stream
745745
: public detail::OwnerLessBase<stream> {
746+
private:
747+
#ifndef __SYCL_DEVICE_ONLY__
748+
// Constructor for recreating a stream.
749+
stream(std::shared_ptr<detail::stream_impl> Impl,
750+
detail::GlobalBufAccessorT GlobalBuf,
751+
detail::GlobalOffsetAccessorT GlobalOffset,
752+
detail::GlobalBufAccessorT GlobalFlushBuf)
753+
: impl{Impl}, GlobalBuf{GlobalBuf}, GlobalOffset{GlobalOffset},
754+
GlobalFlushBuf{GlobalFlushBuf} {}
755+
#endif
756+
746757
public:
747758
#ifdef __SYCL_DEVICE_ONLY__
748759
// Default constructor for objects later initialized with __init member.
@@ -811,6 +822,10 @@ class __SYCL_EXPORT __SYCL_SPECIAL_CLASS __SYCL_TYPE(stream) stream
811822
friend decltype(Obj::impl) detail::getSyclObjImpl(const Obj &SyclObject);
812823
#endif
813824

825+
// NOTE: Some members are required for reconstructing the stream, but are not
826+
// part of the implementation class. If more members are added, they should
827+
// also be added to the weak_object specialization for streams.
828+
814829
// Accessor to the global stream buffer. Global buffer contains all output
815830
// from the kernel.
816831
mutable detail::GlobalBufAccessorT GlobalBuf;
@@ -942,6 +957,8 @@ class __SYCL_EXPORT __SYCL_SPECIAL_CLASS __SYCL_TYPE(stream) stream
942957

943958
friend class handler;
944959

960+
template <typename SYCLObjT> friend class ext::oneapi::weak_object;
961+
945962
friend const stream &operator<<(const stream &, const char);
946963
friend const stream &operator<<(const stream &, const char *);
947964
template <typename ValueType>

sycl/test-e2e/WeakObject/weak_object_utils.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ template <template <typename> typename CallableT> void runTest(sycl::queue Q) {
3737
sycl::accessor<int, 3, sycl::access::mode::read_write,
3838
sycl::access::target::host_buffer>
3939
HAcc3D;
40+
sycl::host_accessor<int, 1> HAcc1D_2020;
41+
sycl::host_accessor<int, 2> HAcc2D_2020;
42+
sycl::host_accessor<int, 3> HAcc3D_2020;
4043

4144
CallableT<decltype(Plt)>()(Plt);
4245
CallableT<decltype(Dev)>()(Dev);
@@ -54,6 +57,9 @@ template <template <typename> typename CallableT> void runTest(sycl::queue Q) {
5457
CallableT<decltype(HAcc1D)>()(HAcc1D);
5558
CallableT<decltype(HAcc2D)>()(HAcc2D);
5659
CallableT<decltype(HAcc3D)>()(HAcc3D);
60+
CallableT<decltype(HAcc1D_2020)>()(HAcc1D_2020);
61+
CallableT<decltype(HAcc2D_2020)>()(HAcc2D_2020);
62+
CallableT<decltype(HAcc3D_2020)>()(HAcc3D_2020);
5763

5864
Q.submit([&](sycl::handler &CGH) {
5965
sycl::accessor DAcc1D{Buf1D, CGH, sycl::read_only};
@@ -62,13 +68,15 @@ template <template <typename> typename CallableT> void runTest(sycl::queue Q) {
6268
sycl::local_accessor<int> LAcc1D{1, CGH};
6369
sycl::local_accessor<int, 2> LAcc2D{sycl::range<2>{1, 2}, CGH};
6470
sycl::local_accessor<int, 3> LAcc3D{sycl::range<3>{1, 2, 3}, CGH};
71+
sycl::stream Stream{1024, 32, CGH};
6572

6673
CallableT<decltype(DAcc1D)>()(DAcc1D);
6774
CallableT<decltype(DAcc2D)>()(DAcc2D);
6875
CallableT<decltype(DAcc3D)>()(DAcc3D);
6976
CallableT<decltype(LAcc1D)>()(LAcc1D);
7077
CallableT<decltype(LAcc2D)>()(LAcc2D);
7178
CallableT<decltype(LAcc3D)>()(LAcc3D);
79+
CallableT<decltype(Stream)>()(Stream);
7280
});
7381
}
7482

@@ -120,6 +128,12 @@ void runTestMulti(sycl::queue Q1) {
120128
sycl::accessor<int, 3, sycl::access::mode::read_write,
121129
sycl::access::target::host_buffer>
122130
HAcc3D2;
131+
sycl::host_accessor<int, 1> HAcc1D1_2020;
132+
sycl::host_accessor<int, 2> HAcc2D1_2020;
133+
sycl::host_accessor<int, 3> HAcc3D1_2020;
134+
sycl::host_accessor<int, 1> HAcc1D2_2020;
135+
sycl::host_accessor<int, 2> HAcc2D2_2020;
136+
sycl::host_accessor<int, 3> HAcc3D2_2020;
123137

124138
CallableT<decltype(Ctx1)>()(Ctx1, Ctx2);
125139
CallableT<decltype(Q1)>()(Q1, Q2);
@@ -135,6 +149,9 @@ void runTestMulti(sycl::queue Q1) {
135149
CallableT<decltype(HAcc1D1)>()(HAcc1D1, HAcc1D2);
136150
CallableT<decltype(HAcc2D1)>()(HAcc2D1, HAcc2D2);
137151
CallableT<decltype(HAcc3D1)>()(HAcc3D1, HAcc3D2);
152+
CallableT<decltype(HAcc1D1_2020)>()(HAcc1D1_2020, HAcc1D2_2020);
153+
CallableT<decltype(HAcc2D1_2020)>()(HAcc2D1_2020, HAcc2D2_2020);
154+
CallableT<decltype(HAcc3D1_2020)>()(HAcc3D1_2020, HAcc3D2_2020);
138155

139156
Q1.submit([&](sycl::handler &CGH) {
140157
sycl::accessor DAcc1D1{Buf1D1, CGH, sycl::read_only};
@@ -149,12 +166,15 @@ void runTestMulti(sycl::queue Q1) {
149166
sycl::local_accessor<int, 2> LAcc2D2{sycl::range<2>{1, 2}, CGH};
150167
sycl::local_accessor<int, 3> LAcc3D1{sycl::range<3>{1, 2, 3}, CGH};
151168
sycl::local_accessor<int, 3> LAcc3D2{sycl::range<3>{1, 2, 3}, CGH};
169+
sycl::stream Stream1{1024, 32, CGH};
170+
sycl::stream Stream2{1024, 32, CGH};
152171

153172
CallableT<decltype(DAcc1D1)>()(DAcc1D1, DAcc1D2);
154173
CallableT<decltype(DAcc2D1)>()(DAcc2D1, DAcc2D2);
155174
CallableT<decltype(DAcc3D1)>()(DAcc3D1, DAcc3D2);
156175
CallableT<decltype(LAcc1D1)>()(LAcc1D1, LAcc1D2);
157176
CallableT<decltype(LAcc2D1)>()(LAcc2D1, LAcc2D2);
158177
CallableT<decltype(LAcc3D1)>()(LAcc3D1, LAcc3D2);
178+
CallableT<decltype(Stream1)>()(Stream1, Stream2);
159179
});
160180
}

0 commit comments

Comments
 (0)