Skip to content

Commit c73c9d0

Browse files
[NFC][SYCL] Raw context_impl in getInteropContext and queue_impl ctor (#19126)
Splitting into two PRs would result in unnecessary temporarily adjustments and merge conflicts later on between these changes, so perform in a single PR. They are both small enough anyway. Continuation of the refactoring in #18795 #18877 #18966 #18979 #18980 #18981 #19007 #19030 #19123
1 parent 3b007d9 commit c73c9d0

File tree

9 files changed

+37
-41
lines changed

9 files changed

+37
-41
lines changed

sycl/source/backend.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
126126
ur_device_handle_t UrDevice =
127127
Device ? getSyclObjImpl(*Device)->getHandleRef() : nullptr;
128128
const auto &Adapter = getAdapter(Backend);
129-
const auto &ContextImpl = getSyclObjImpl(Context);
129+
context_impl &ContextImpl = *getSyclObjImpl(Context);
130130

131131
if (PropList.has_property<ext::intel::property::queue::compute_index>()) {
132132
throw sycl::exception(
@@ -156,7 +156,7 @@ __SYCL_EXPORT queue make_queue(ur_native_handle_t NativeHandle,
156156
ur_queue_handle_t UrQueue = nullptr;
157157

158158
Adapter->call<UrApiKind::urQueueCreateWithNativeHandle>(
159-
NativeHandle, ContextImpl->getHandleRef(), UrDevice, &NativeProperties,
159+
NativeHandle, ContextImpl.getHandleRef(), UrDevice, &NativeProperties,
160160
&UrQueue);
161161
// Construct the SYCL queue from UR queue.
162162
return detail::createSyclObjFromImpl<queue>(

sycl/source/detail/graph_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ exec_graph_impl::exec_graph_impl(sycl::context Context,
996996
: MSchedule(), MGraphImpl(GraphImpl), MSyncPoints(),
997997
MQueueImpl(sycl::detail::queue_impl::create(
998998
*sycl::detail::getSyclObjImpl(GraphImpl->getDevice()),
999-
sycl::detail::getSyclObjImpl(Context), sycl::async_handler{},
999+
*sycl::detail::getSyclObjImpl(Context), sycl::async_handler{},
10001000
sycl::property_list{})),
10011001
MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(),
10021002
MSchedulerDependencies(),

sycl/source/detail/queue_impl.hpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,11 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
117117
/// constructed.
118118
/// \param AsyncHandler is a SYCL asynchronous exception handler.
119119
/// \param PropList is a list of properties to use for queue construction.
120-
queue_impl(device_impl &Device, const ContextImplPtr &Context,
120+
queue_impl(device_impl &Device, std::shared_ptr<context_impl> &&Context,
121121
const async_handler &AsyncHandler, const property_list &PropList,
122122
private_tag)
123-
: MDevice(Device), MContext(Context), MAsyncHandler(AsyncHandler),
124-
MPropList(PropList),
123+
: MDevice(Device), MContext(std::move(Context)),
124+
MAsyncHandler(AsyncHandler), MPropList(PropList),
125125
MIsInorder(has_property<property::queue::in_order>()),
126126
MIsProfilingEnabled(has_property<property::queue::enable_profiling>()),
127127
MQueueID{
@@ -146,8 +146,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
146146
"Queue compute index must be a non-negative number less than "
147147
"device's number of available compute queue indices.");
148148
}
149-
if (!Context->isDeviceValid(Device)) {
150-
if (Context->getBackend() == backend::opencl)
149+
if (!MContext->isDeviceValid(Device)) {
150+
if (MContext->getBackend() == backend::opencl)
151151
throw sycl::exception(
152152
make_error_code(errc::invalid),
153153
"Queue cannot be constructed with the given context and device "
@@ -177,17 +177,13 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
177177
trySwitchingToNoEventsMode();
178178
}
179179

180-
sycl::detail::optional<event> getLastEvent();
180+
queue_impl(device_impl &Device, context_impl &Context,
181+
const async_handler &AsyncHandler, const property_list &PropList,
182+
private_tag Tag)
183+
: queue_impl(Device, Context.shared_from_this(), AsyncHandler, PropList,
184+
Tag) {}
181185

182-
/// Constructs a SYCL queue from adapter interoperability handle.
183-
///
184-
/// \param UrQueue is a raw UR queue handle.
185-
/// \param Context is a SYCL context to associate with the queue being
186-
/// constructed.
187-
/// \param AsyncHandler is a SYCL asynchronous exception handler.
188-
queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context,
189-
const async_handler &AsyncHandler, private_tag tag)
190-
: queue_impl(UrQueue, Context, AsyncHandler, {}, tag) {}
186+
sycl::detail::optional<event> getLastEvent();
191187

192188
/// Constructs a SYCL queue from adapter interoperability handle.
193189
///
@@ -196,27 +192,28 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
196192
/// constructed.
197193
/// \param AsyncHandler is a SYCL asynchronous exception handler.
198194
/// \param PropList is the queue properties.
199-
queue_impl(ur_queue_handle_t UrQueue, const ContextImplPtr &Context,
195+
queue_impl(ur_queue_handle_t UrQueue, context_impl &Context,
200196
const async_handler &AsyncHandler, const property_list &PropList,
201197
private_tag)
202198
: MDevice([&]() -> device_impl & {
203199
ur_device_handle_t DeviceUr{};
204-
const AdapterPtr &Adapter = Context->getAdapter();
200+
const AdapterPtr &Adapter = Context.getAdapter();
205201
// TODO catch an exception and put it to list of asynchronous
206202
// exceptions
207203
Adapter->call<UrApiKind::urQueueGetInfo>(
208204
UrQueue, UR_QUEUE_INFO_DEVICE, sizeof(DeviceUr), &DeviceUr,
209205
nullptr);
210-
device_impl *Device = Context->findMatchingDeviceImpl(DeviceUr);
206+
device_impl *Device = Context.findMatchingDeviceImpl(DeviceUr);
211207
if (Device == nullptr) {
212208
throw sycl::exception(
213209
make_error_code(errc::invalid),
214210
"Device provided by native Queue not found in Context.");
215211
}
216212
return *Device;
217213
}()),
218-
MContext(Context), MAsyncHandler(AsyncHandler), MPropList(PropList),
219-
MQueue(UrQueue), MIsInorder(has_property<property::queue::in_order>()),
214+
MContext(Context.shared_from_this()), MAsyncHandler(AsyncHandler),
215+
MPropList(PropList), MQueue(UrQueue),
216+
MIsInorder(has_property<property::queue::in_order>()),
220217
MIsProfilingEnabled(has_property<property::queue::enable_profiling>()),
221218
MQueueID{
222219
MNextAvailableQueueID.fetch_add(1, std::memory_order_relaxed)} {
@@ -988,7 +985,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
988985
mutable std::mutex MMutex;
989986

990987
device_impl &MDevice;
991-
const ContextImplPtr MContext;
988+
const std::shared_ptr<context_impl> MContext;
992989

993990
/// These events are tracked, but not owned, by the queue.
994991
std::vector<std::weak_ptr<event_impl>> MEventsWeak;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
210210
cleanupCommand(Cmd);
211211
};
212212

213-
const ContextImplPtr &InteropCtxPtr = Req->MSYCLMemObj->getInteropContext();
213+
context_impl *InteropCtxPtr = Req->MSYCLMemObj->getInteropContext();
214214
if (InteropCtxPtr) {
215215
// The memory object has been constructed using interoperability constructor
216216
// which means that there is already an allocation(cl_mem) in some context.
@@ -225,10 +225,10 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
225225
// here, we need to create a dummy queue bound to the context and one of the
226226
// devices from the context.
227227
std::shared_ptr<queue_impl> InteropQueuePtr = queue_impl::create(
228-
Dev, InteropCtxPtr, async_handler{}, property_list{});
228+
Dev, *InteropCtxPtr, async_handler{}, property_list{});
229229

230230
MemObject->MRecord.reset(
231-
new MemObjRecord{InteropCtxPtr.get(), LeafLimit, AllocateDependency});
231+
new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency});
232232
std::vector<Command *> ToEnqueue;
233233
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr,
234234
ToEnqueue);

sycl/source/detail/sycl_mem_obj_i.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ class context_impl;
2222
struct MemObjRecord;
2323

2424
using EventImplPtr = std::shared_ptr<detail::event_impl>;
25-
using ContextImplPtr = std::shared_ptr<detail::context_impl>;
2625

2726
// The class serves as an interface in the scheduler for all SYCL memory
2827
// objects.
@@ -72,7 +71,7 @@ class SYCLMemObjI {
7271

7372
// Returns the context which is passed if a memory object is created using
7473
// interoperability constructor, nullptr otherwise.
75-
virtual ContextImplPtr getInteropContext() const = 0;
74+
virtual detail::context_impl *getInteropContext() const = 0;
7675

7776
protected:
7877
// Pointer to the record that contains the memory commands. This is managed

sycl/source/detail/sycl_mem_obj_t.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class event_impl;
3636
class Adapter;
3737
using AdapterPtr = std::shared_ptr<Adapter>;
3838

39-
using ContextImplPtr = std::shared_ptr<context_impl>;
4039
using EventImplPtr = std::shared_ptr<event_impl>;
4140

4241
// The class serves as a base for all SYCL memory objects.
@@ -281,7 +280,9 @@ class SYCLMemObjT : public SYCLMemObjI {
281280

282281
MemObjType getType() const override { return MemObjType::Undefined; }
283282

284-
ContextImplPtr getInteropContext() const override { return MInteropContext; }
283+
context_impl *getInteropContext() const override {
284+
return MInteropContext.get();
285+
}
285286

286287
bool isInterop() const override;
287288

@@ -339,7 +340,7 @@ class SYCLMemObjT : public SYCLMemObjI {
339340
// Should wait on this event before start working with such memory object.
340341
EventImplPtr MInteropEvent;
341342
// Context passed by user to interoperability constructor.
342-
ContextImplPtr MInteropContext;
343+
std::shared_ptr<context_impl> MInteropContext;
343344
// Native backend memory object handle passed by user to interoperability
344345
// constructor.
345346
ur_mem_handle_t MInteropMemObject;

sycl/source/queue.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ queue::queue(const context &SyclContext, const device_selector &DeviceSelector,
6565
const device &SyclDevice = *std::max_element(Devs.begin(), Devs.end(), Comp);
6666

6767
impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice),
68-
detail::getSyclObjImpl(SyclContext),
68+
*detail::getSyclObjImpl(SyclContext),
6969
AsyncHandler, PropList);
7070
}
7171

7272
queue::queue(const context &SyclContext, const device &SyclDevice,
7373
const async_handler &AsyncHandler, const property_list &PropList) {
7474
impl = detail::queue_impl::create(*detail::getSyclObjImpl(SyclDevice),
75-
detail::getSyclObjImpl(SyclContext),
75+
*detail::getSyclObjImpl(SyclContext),
7676
AsyncHandler, PropList);
7777
}
7878

@@ -100,7 +100,7 @@ queue::queue(cl_command_queue clQueue, const context &SyclContext,
100100
impl = detail::queue_impl::create(
101101
// TODO(pi2ur): Don't cast straight from cl_command_queue
102102
reinterpret_cast<ur_queue_handle_t>(clQueue),
103-
detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList);
103+
*detail::getSyclObjImpl(SyclContext), AsyncHandler, PropList);
104104
}
105105

106106
cl_command_queue queue::get() const { return impl->get(); }

sycl/unittests/scheduler/HostTaskAndBarrier.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
namespace {
2121
using namespace sycl;
2222
using EventImplPtr = std::shared_ptr<sycl::detail::event_impl>;
23-
using ContextImplPtr = std::shared_ptr<sycl::detail::context_impl>;
2423

2524
constexpr auto DisableCleanupName = "SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP";
2625

2726
class TestQueueImpl : public sycl::detail::queue_impl {
2827
public:
29-
TestQueueImpl(ContextImplPtr SyclContext, sycl::detail::device_impl &Dev)
28+
TestQueueImpl(sycl::detail::context_impl &SyclContext,
29+
sycl::detail::device_impl &Dev)
3030
: sycl::detail::queue_impl(Dev, SyclContext,
31-
SyclContext->get_async_handler(), {},
31+
SyclContext.get_async_handler(), {},
3232
sycl::detail::queue_impl::private_tag{}) {}
3333
using sycl::detail::queue_impl::MDefaultGraphDeps;
3434
using sycl::detail::queue_impl::MExtGraphDeps;
@@ -46,7 +46,7 @@ class BarrierHandlingWithHostTask : public ::testing::Test {
4646
sycl::device SyclDev =
4747
sycl::detail::select_device(sycl::default_selector_v, SyclContext);
4848
QueueDevImpl.reset(
49-
new TestQueueImpl(sycl::detail::getSyclObjImpl(SyclContext),
49+
new TestQueueImpl(*sycl::detail::getSyclObjImpl(SyclContext),
5050
*sycl::detail::getSyclObjImpl(SyclDev)));
5151

5252
MainLock.lock();

sycl/unittests/scheduler/LinkedAllocaDependencies.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ class MemObjMock : public sycl::detail::SYCLMemObjI {
3838
bool isHostPointerReadOnly() const override { return false; }
3939
bool usesPinnedHostMemory() const override { return false; }
4040

41-
std::shared_ptr<sycl::detail::context_impl>
42-
getInteropContext() const override {
41+
sycl::detail::context_impl *getInteropContext() const override {
4342
return nullptr;
4443
}
4544
};

0 commit comments

Comments
 (0)