Skip to content

Commit a697821

Browse files
[NFC][SYCL] Pass queue_impl by raw ptr in commands.hpp (#19004)
Continuation of the refactoring efforts in #18715 #18748 #18830 #18907 #18983
1 parent 6d49f27 commit a697821

17 files changed

+204
-189
lines changed

sycl/source/detail/cg.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -725,14 +725,10 @@ class CGHostTask : public CG {
725725
std::shared_ptr<detail::context_impl> MContext;
726726
std::vector<ArgDesc> MArgs;
727727

728-
CGHostTask(std::shared_ptr<HostTask> HostTask,
729-
std::shared_ptr<detail::queue_impl> Queue,
728+
CGHostTask(std::shared_ptr<HostTask> HostTask, detail::queue_impl *Queue,
730729
std::shared_ptr<detail::context_impl> Context,
731730
std::vector<ArgDesc> Args, CG::StorageInitHelper CGData,
732-
CGType Type, detail::code_location loc = {})
733-
: CG(Type, std::move(CGData), std::move(loc)),
734-
MHostTask(std::move(HostTask)), MQueue(Queue), MContext(Context),
735-
MArgs(std::move(Args)) {}
731+
CGType Type, detail::code_location loc = {});
736732
};
737733

738734
} // namespace detail

sycl/source/detail/graph_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
288288

289289
return std::make_unique<sycl::detail::CGHostTask>(
290290
sycl::detail::CGHostTask(
291-
std::move(HostTaskSPtr), CommandGroupPtr->MQueue,
291+
std::move(HostTaskSPtr), CommandGroupPtr->MQueue.get(),
292292
CommandGroupPtr->MContext, std::move(NewArgs), std::move(Data),
293293
CommandGroupPtr->getType(), Loc));
294294
}

sycl/source/detail/queue_impl.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,9 +650,12 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
650650
// for in order ones.
651651
void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);
652652

653-
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
653+
static ContextImplPtr getContext(queue_impl *Queue) {
654654
return Queue ? Queue->getContextImplPtr() : nullptr;
655655
}
656+
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
657+
return getContext(Queue.get());
658+
}
656659

657660
// Must be called under MMutex protection
658661
void doUnenqueuedCommandCleanup(

sycl/source/detail/scheduler/commands.cpp

Lines changed: 62 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,14 @@ static unsigned long long getQueueID(const std::shared_ptr<queue_impl> &Queue) {
127127
}
128128
#endif
129129

130-
static context_impl *getContext(const QueueImplPtr &Queue) {
130+
static context_impl *getContext(queue_impl *Queue) {
131131
if (Queue)
132132
return &Queue->getContextImpl();
133133
return nullptr;
134134
}
135+
static context_impl *getContext(const std::shared_ptr<queue_impl> &Queue) {
136+
return getContext(Queue.get());
137+
}
135138

136139
#ifdef __SYCL_ENABLE_GNU_DEMANGLING
137140
struct DemangleHandle {
@@ -503,7 +506,7 @@ void Command::waitForPreparedHostEvents() const {
503506
HostEvent->waitInternal();
504507
}
505508

506-
void Command::waitForEvents(QueueImplPtr Queue,
509+
void Command::waitForEvents(queue_impl *Queue,
507510
std::vector<EventImplPtr> &EventImpls,
508511
ur_event_handle_t &Event) {
509512
#ifndef NDEBUG
@@ -559,12 +562,12 @@ void Command::waitForEvents(QueueImplPtr Queue,
559562
/// references to event_impl class members because Command
560563
/// should not outlive the event connected to it.
561564
Command::Command(
562-
CommandType Type, QueueImplPtr Queue,
565+
CommandType Type, queue_impl *Queue,
563566
ur_exp_command_buffer_handle_t CommandBuffer,
564567
const std::vector<ur_exp_command_buffer_sync_point_t> &SyncPoints)
565-
: MQueue(std::move(Queue)),
566-
MEvent(MQueue ? detail::event_impl::create_device_event(*MQueue)
567-
: detail::event_impl::create_incomplete_host_event()),
568+
: MQueue(Queue ? Queue->shared_from_this() : nullptr),
569+
MEvent(Queue ? detail::event_impl::create_device_event(*Queue)
570+
: detail::event_impl::create_incomplete_host_event()),
568571
MPreparedDepsEvents(MEvent->getPreparedDepsEvents()),
569572
MPreparedHostDepsEvents(MEvent->getPreparedHostDepsEvents()), MType(Type),
570573
MCommandBuffer(CommandBuffer), MSyncPointDeps(SyncPoints) {
@@ -1027,7 +1030,7 @@ void Command::copySubmissionCodeLocation() {
10271030
#endif
10281031
}
10291032

1030-
AllocaCommandBase::AllocaCommandBase(CommandType Type, QueueImplPtr Queue,
1033+
AllocaCommandBase::AllocaCommandBase(CommandType Type, queue_impl *Queue,
10311034
Requirement Req,
10321035
AllocaCommandBase *LinkedAllocaCmd,
10331036
bool IsConst)
@@ -1070,10 +1073,10 @@ bool AllocaCommandBase::supportsPostEnqueueCleanup() const { return false; }
10701073

10711074
bool AllocaCommandBase::readyForCleanup() const { return false; }
10721075

1073-
AllocaCommand::AllocaCommand(QueueImplPtr Queue, Requirement Req,
1076+
AllocaCommand::AllocaCommand(queue_impl *Queue, Requirement Req,
10741077
bool InitFromUserData,
10751078
AllocaCommandBase *LinkedAllocaCmd, bool IsConst)
1076-
: AllocaCommandBase(CommandType::ALLOCA, std::move(Queue), std::move(Req),
1079+
: AllocaCommandBase(CommandType::ALLOCA, Queue, std::move(Req),
10771080
LinkedAllocaCmd, IsConst),
10781081
MInitFromUserData(InitFromUserData) {
10791082
// Node event must be created before the dependent edge is added to this
@@ -1108,7 +1111,7 @@ ur_result_t AllocaCommand::enqueueImp() {
11081111

11091112
if (!MQueue) {
11101113
// Do not need to make allocation if we have a linked device allocation
1111-
Command::waitForEvents(MQueue, EventImpls, UREvent);
1114+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
11121115
MEvent->setHandle(UREvent);
11131116

11141117
return UR_RESULT_SUCCESS;
@@ -1148,12 +1151,11 @@ void AllocaCommand::printDot(std::ostream &Stream) const {
11481151
}
11491152
}
11501153

1151-
AllocaSubBufCommand::AllocaSubBufCommand(QueueImplPtr Queue, Requirement Req,
1154+
AllocaSubBufCommand::AllocaSubBufCommand(queue_impl *Queue, Requirement Req,
11521155
AllocaCommandBase *ParentAlloca,
11531156
std::vector<Command *> &ToEnqueue,
11541157
std::vector<Command *> &ToCleanUp)
1155-
: AllocaCommandBase(CommandType::ALLOCA_SUB_BUF, std::move(Queue),
1156-
std::move(Req),
1158+
: AllocaCommandBase(CommandType::ALLOCA_SUB_BUF, Queue, std::move(Req),
11571159
/*LinkedAllocaCmd*/ nullptr, /*IsConst*/ false),
11581160
MParentAlloca(ParentAlloca) {
11591161
// Node event must be created before the dependent edge
@@ -1234,8 +1236,8 @@ void AllocaSubBufCommand::printDot(std::ostream &Stream) const {
12341236
}
12351237
}
12361238

1237-
ReleaseCommand::ReleaseCommand(QueueImplPtr Queue, AllocaCommandBase *AllocaCmd)
1238-
: Command(CommandType::RELEASE, std::move(Queue)), MAllocaCmd(AllocaCmd) {
1239+
ReleaseCommand::ReleaseCommand(queue_impl *Queue, AllocaCommandBase *AllocaCmd)
1240+
: Command(CommandType::RELEASE, Queue), MAllocaCmd(AllocaCmd) {
12391241
emitInstrumentationDataProxy();
12401242
}
12411243

@@ -1288,9 +1290,9 @@ ur_result_t ReleaseCommand::enqueueImp() {
12881290
}
12891291

12901292
if (NeedUnmap) {
1291-
const QueueImplPtr &Queue = CurAllocaIsHost
1292-
? MAllocaCmd->MLinkedAllocaCmd->getQueue()
1293-
: MAllocaCmd->getQueue();
1293+
queue_impl *Queue = CurAllocaIsHost
1294+
? MAllocaCmd->MLinkedAllocaCmd->getQueue()
1295+
: MAllocaCmd->getQueue();
12941296

12951297
assert(Queue);
12961298

@@ -1321,7 +1323,7 @@ ur_result_t ReleaseCommand::enqueueImp() {
13211323
}
13221324
ur_event_handle_t UREvent = nullptr;
13231325
if (SkipRelease)
1324-
Command::waitForEvents(MQueue, EventImpls, UREvent);
1326+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
13251327
else {
13261328
if (auto Result = callMemOpHelper(
13271329
MemoryManager::release, getContext(MQueue),
@@ -1359,11 +1361,10 @@ bool ReleaseCommand::supportsPostEnqueueCleanup() const { return false; }
13591361
bool ReleaseCommand::readyForCleanup() const { return false; }
13601362

13611363
MapMemObject::MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req,
1362-
void **DstPtr, QueueImplPtr Queue,
1364+
void **DstPtr, queue_impl *Queue,
13631365
access::mode MapMode)
1364-
: Command(CommandType::MAP_MEM_OBJ, std::move(Queue)),
1365-
MSrcAllocaCmd(SrcAllocaCmd), MSrcReq(std::move(Req)), MDstPtr(DstPtr),
1366-
MMapMode(MapMode) {
1366+
: Command(CommandType::MAP_MEM_OBJ, Queue), MSrcAllocaCmd(SrcAllocaCmd),
1367+
MSrcReq(std::move(Req)), MDstPtr(DstPtr), MMapMode(MapMode) {
13671368
emitInstrumentationDataProxy();
13681369
}
13691370

@@ -1423,9 +1424,9 @@ void MapMemObject::printDot(std::ostream &Stream) const {
14231424
}
14241425

14251426
UnMapMemObject::UnMapMemObject(AllocaCommandBase *DstAllocaCmd, Requirement Req,
1426-
void **SrcPtr, QueueImplPtr Queue)
1427-
: Command(CommandType::UNMAP_MEM_OBJ, std::move(Queue)),
1428-
MDstAllocaCmd(DstAllocaCmd), MDstReq(std::move(Req)), MSrcPtr(SrcPtr) {
1427+
void **SrcPtr, queue_impl *Queue)
1428+
: Command(CommandType::UNMAP_MEM_OBJ, Queue), MDstAllocaCmd(DstAllocaCmd),
1429+
MDstReq(std::move(Req)), MSrcPtr(SrcPtr) {
14291430
emitInstrumentationDataProxy();
14301431
}
14311432

@@ -1509,11 +1510,11 @@ MemCpyCommand::MemCpyCommand(Requirement SrcReq,
15091510
AllocaCommandBase *SrcAllocaCmd,
15101511
Requirement DstReq,
15111512
AllocaCommandBase *DstAllocaCmd,
1512-
QueueImplPtr SrcQueue, QueueImplPtr DstQueue)
1513-
: Command(CommandType::COPY_MEMORY, std::move(DstQueue)),
1514-
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)),
1515-
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)),
1516-
MDstAllocaCmd(DstAllocaCmd) {
1513+
queue_impl *SrcQueue, queue_impl *DstQueue)
1514+
: Command(CommandType::COPY_MEMORY, DstQueue),
1515+
MSrcQueue(SrcQueue ? SrcQueue->shared_from_this() : nullptr),
1516+
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
1517+
MDstReq(std::move(DstReq)), MDstAllocaCmd(DstAllocaCmd) {
15171518
if (MSrcQueue) {
15181519
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
15191520
}
@@ -1645,7 +1646,7 @@ ur_result_t UpdateHostRequirementCommand::enqueueImp() {
16451646
waitForPreparedHostEvents();
16461647
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
16471648
ur_event_handle_t UREvent = nullptr;
1648-
Command::waitForEvents(MQueue, EventImpls, UREvent);
1649+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
16491650
MEvent->setHandle(UREvent);
16501651

16511652
assert(MSrcAllocaCmd && "Expected valid alloca command");
@@ -1682,11 +1683,11 @@ void UpdateHostRequirementCommand::printDot(std::ostream &Stream) const {
16821683
MemCpyCommandHost::MemCpyCommandHost(Requirement SrcReq,
16831684
AllocaCommandBase *SrcAllocaCmd,
16841685
Requirement DstReq, void **DstPtr,
1685-
QueueImplPtr SrcQueue,
1686-
QueueImplPtr DstQueue)
1687-
: Command(CommandType::COPY_MEMORY, std::move(DstQueue)),
1688-
MSrcQueue(SrcQueue), MSrcReq(std::move(SrcReq)),
1689-
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
1686+
queue_impl *SrcQueue, queue_impl *DstQueue)
1687+
: Command(CommandType::COPY_MEMORY, DstQueue),
1688+
MSrcQueue(SrcQueue ? SrcQueue->shared_from_this() : nullptr),
1689+
MSrcReq(std::move(SrcReq)), MSrcAllocaCmd(SrcAllocaCmd),
1690+
MDstReq(std::move(DstReq)), MDstPtr(DstPtr) {
16901691
if (MSrcQueue) {
16911692
MEvent->setContextImpl(MSrcQueue->getContextImplPtr());
16921693
}
@@ -1728,7 +1729,7 @@ ContextImplPtr MemCpyCommandHost::getWorkerContext() const {
17281729
}
17291730

17301731
ur_result_t MemCpyCommandHost::enqueueImp() {
1731-
const QueueImplPtr &Queue = MWorkerQueue;
1732+
queue_impl *Queue = MWorkerQueue.get();
17321733
waitForPreparedHostEvents();
17331734
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
17341735
std::vector<ur_event_handle_t> RawEvents = getUrEvents(EventImpls);
@@ -1767,7 +1768,7 @@ EmptyCommand::EmptyCommand() : Command(CommandType::EMPTY_TASK, nullptr) {
17671768
ur_result_t EmptyCommand::enqueueImp() {
17681769
waitForPreparedHostEvents();
17691770
ur_event_handle_t UREvent = nullptr;
1770-
waitForEvents(MQueue, MPreparedDepsEvents, UREvent);
1771+
waitForEvents(MQueue.get(), MPreparedDepsEvents, UREvent);
17711772
MEvent->setHandle(UREvent);
17721773
return UR_RESULT_SUCCESS;
17731774
}
@@ -1851,9 +1852,9 @@ void MemCpyCommandHost::printDot(std::ostream &Stream) const {
18511852
}
18521853

18531854
UpdateHostRequirementCommand::UpdateHostRequirementCommand(
1854-
QueueImplPtr Queue, Requirement Req, AllocaCommandBase *SrcAllocaCmd,
1855+
queue_impl *Queue, Requirement Req, AllocaCommandBase *SrcAllocaCmd,
18551856
void **DstPtr)
1856-
: Command(CommandType::UPDATE_REQUIREMENT, std::move(Queue)),
1857+
: Command(CommandType::UPDATE_REQUIREMENT, Queue),
18571858
MSrcAllocaCmd(SrcAllocaCmd), MDstReq(std::move(Req)), MDstPtr(DstPtr) {
18581859

18591860
emitInstrumentationDataProxy();
@@ -1949,11 +1950,10 @@ static std::string_view cgTypeToString(detail::CGType Type) {
19491950
}
19501951

19511952
ExecCGCommand::ExecCGCommand(
1952-
std::unique_ptr<detail::CG> CommandGroup, QueueImplPtr Queue,
1953+
std::unique_ptr<detail::CG> CommandGroup, queue_impl *Queue,
19531954
bool EventNeeded, ur_exp_command_buffer_handle_t CommandBuffer,
19541955
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies)
1955-
: Command(CommandType::RUN_CG, std::move(Queue), CommandBuffer,
1956-
Dependencies),
1956+
: Command(CommandType::RUN_CG, Queue, CommandBuffer, Dependencies),
19571957
MEventNeeded(EventNeeded), MCommandGroup(std::move(CommandGroup)) {
19581958
if (MCommandGroup->getType() == detail::CGType::CodeplayHostTask) {
19591959
MEvent->setSubmittedQueue(
@@ -2770,20 +2770,18 @@ void enqueueImpKernel(
27702770
}
27712771
}
27722772

2773-
ur_result_t enqueueReadWriteHostPipe(const QueueImplPtr &Queue,
2773+
ur_result_t enqueueReadWriteHostPipe(queue_impl &Queue,
27742774
const std::string &PipeName, bool blocking,
27752775
void *ptr, size_t size,
27762776
std::vector<ur_event_handle_t> &RawEvents,
27772777
detail::event_impl *OutEventImpl,
27782778
bool read) {
2779-
assert(Queue &&
2780-
"ReadWrite host pipe submissions should have an associated queue");
27812779
detail::HostPipeMapEntry *hostPipeEntry =
27822780
ProgramManager::getInstance().getHostPipeEntry(PipeName);
27832781

27842782
ur_program_handle_t Program = nullptr;
2785-
device Device = Queue->get_device();
2786-
ContextImplPtr ContextImpl = Queue->getContextImplPtr();
2783+
device Device = Queue.get_device();
2784+
ContextImplPtr ContextImpl = Queue.getContextImplPtr();
27872785
std::optional<ur_program_handle_t> CachedProgram =
27882786
ContextImpl->getProgramForHostPipe(Device, hostPipeEntry);
27892787
if (CachedProgram)
@@ -2792,17 +2790,16 @@ ur_result_t enqueueReadWriteHostPipe(const QueueImplPtr &Queue,
27922790
// If there was no cached program, build one.
27932791
device_image_plain devImgPlain =
27942792
ProgramManager::getInstance().getDeviceImageFromBinaryImage(
2795-
hostPipeEntry->getDevBinImage(), Queue->get_context(),
2796-
Queue->get_device());
2793+
hostPipeEntry->getDevBinImage(), Queue.get_context(), Device);
27972794
device_image_plain BuiltImage = ProgramManager::getInstance().build(
27982795
std::move(devImgPlain), {std::move(Device)}, {});
27992796
Program = getSyclObjImpl(BuiltImage)->get_ur_program_ref();
28002797
}
28012798
assert(Program && "Program for this hostpipe is not compiled.");
28022799

2803-
const AdapterPtr &Adapter = Queue->getAdapter();
2800+
const AdapterPtr &Adapter = Queue.getAdapter();
28042801

2805-
ur_queue_handle_t ur_q = Queue->getHandleRef();
2802+
ur_queue_handle_t ur_q = Queue.getHandleRef();
28062803
ur_result_t Error;
28072804

28082805
ur_event_handle_t UREvent = nullptr;
@@ -3660,7 +3657,7 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
36603657
if (!EventImpl) {
36613658
EventImpl = MEvent.get();
36623659
}
3663-
return enqueueReadWriteHostPipe(MQueue, pipeName, blocking, hostPtr,
3660+
return enqueueReadWriteHostPipe(*MQueue, pipeName, blocking, hostPtr,
36643661
typeSize, RawEvents, EventImpl, read);
36653662
}
36663663
case CGType::ExecCommandBuffer: {
@@ -3795,7 +3792,7 @@ bool ExecCGCommand::readyForCleanup() const {
37953792
}
37963793

37973794
UpdateCommandBufferCommand::UpdateCommandBufferCommand(
3798-
QueueImplPtr Queue,
3795+
queue_impl *Queue,
37993796
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
38003797
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
38013798
Nodes)
@@ -3806,7 +3803,7 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() {
38063803
waitForPreparedHostEvents();
38073804
std::vector<EventImplPtr> EventImpls = MPreparedDepsEvents;
38083805
ur_event_handle_t UREvent = nullptr;
3809-
Command::waitForEvents(MQueue, EventImpls, UREvent);
3806+
Command::waitForEvents(MQueue.get(), EventImpls, UREvent);
38103807
MEvent->setHandle(UREvent);
38113808

38123809
auto CheckAndFindAlloca = [](Requirement *Req, const DepDesc &Dep) {
@@ -3878,6 +3875,15 @@ void UpdateCommandBufferCommand::printDot(std::ostream &Stream) const {
38783875
void UpdateCommandBufferCommand::emitInstrumentationData() {}
38793876
bool UpdateCommandBufferCommand::producesPiEvent() const { return false; }
38803877

3878+
CGHostTask::CGHostTask(std::shared_ptr<HostTask> HostTask,
3879+
detail::queue_impl *Queue,
3880+
std::shared_ptr<detail::context_impl> Context,
3881+
std::vector<ArgDesc> Args, CG::StorageInitHelper CGData,
3882+
CGType Type, detail::code_location loc)
3883+
: CG(Type, std::move(CGData), std::move(loc)),
3884+
MHostTask(std::move(HostTask)),
3885+
MQueue(Queue ? Queue->shared_from_this() : nullptr), MContext(Context),
3886+
MArgs(std::move(Args)) {}
38813887
} // namespace detail
38823888
} // namespace _V1
38833889
} // namespace sycl

0 commit comments

Comments
 (0)