Skip to content

Commit adb90af

Browse files
[NFC][SYCL] Pass queue_impl by raw ptr/ref (mostly scheduler) (#19120)
Continuation of the refactoring efforts in #18715 #18748 #18830 #18907 #18983 #19006
1 parent ab1d89f commit adb90af

24 files changed

+154
-177
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
866866

867867
sycl::detail::EventImplPtr Event =
868868
sycl::detail::Scheduler::getInstance().addCG(
869-
Node->getCGCopy(), MQueueImpl,
869+
Node->getCGCopy(), *MQueueImpl,
870870
/*EventNeeded=*/true, CommandBuffer, Deps);
871871

872872
if (MIsUpdatable) {
@@ -1048,7 +1048,7 @@ EventImplPtr exec_graph_impl::enqueueHostTaskPartition(
10481048
NodeCommandGroup->getType()));
10491049

10501050
EventImplPtr SchedulerEvent = sycl::detail::Scheduler::getInstance().addCG(
1051-
std::move(CommandGroup), Queue.shared_from_this(), EventNeeded);
1051+
std::move(CommandGroup), Queue, EventNeeded);
10521052

10531053
if (EventNeeded) {
10541054
return SchedulerEvent;
@@ -1076,7 +1076,7 @@ EventImplPtr exec_graph_impl::enqueuePartitionWithScheduler(
10761076
CommandBuffer, nullptr, std::move(CGData));
10771077

10781078
EventImplPtr SchedulerEvent = sycl::detail::Scheduler::getInstance().addCG(
1079-
std::move(CommandGroup), Queue.shared_from_this(), EventNeeded);
1079+
std::move(CommandGroup), Queue, EventNeeded);
10801080

10811081
if (EventNeeded) {
10821082
SchedulerEvent->setEventFromSubmittedExecCommandBuffer(true);
@@ -1551,7 +1551,7 @@ void exec_graph_impl::update(
15511551
// other scheduler commands
15521552
auto UpdateEvent =
15531553
sycl::detail::Scheduler::getInstance().addCommandGraphUpdate(
1554-
this, Nodes, MQueueImpl, std::move(UpdateRequirements),
1554+
this, Nodes, MQueueImpl.get(), std::move(UpdateRequirements),
15551555
MSchedulerDependencies);
15561556

15571557
MSchedulerDependencies.push_back(UpdateEvent);

sycl/source/detail/queue_impl.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ queue_impl::get_backend_info<info::device::backend_version>() const {
118118
}
119119
#endif
120120

121-
static event prepareSYCLEventAssociatedWithQueue(
122-
const std::shared_ptr<detail::queue_impl> &QueueImpl) {
123-
auto EventImpl = detail::event_impl::create_device_event(*QueueImpl);
124-
EventImpl->setContextImpl(QueueImpl->getContextImpl());
121+
static event
122+
prepareSYCLEventAssociatedWithQueue(detail::queue_impl &QueueImpl) {
123+
auto EventImpl = detail::event_impl::create_device_event(QueueImpl);
124+
EventImpl->setContextImpl(QueueImpl.getContextImpl());
125125
EventImpl->setStateIncomplete();
126126
return detail::createSyclObjFromImpl<event>(EventImpl);
127127
}
@@ -464,7 +464,7 @@ event queue_impl::submitMemOpHelper(const std::vector<event> &DepEvents,
464464
event_impl::create_discarded_event());
465465
}
466466

467-
event ResEvent = prepareSYCLEventAssociatedWithQueue(shared_from_this());
467+
event ResEvent = prepareSYCLEventAssociatedWithQueue(*this);
468468
const auto &EventImpl = detail::getSyclObjImpl(ResEvent);
469469
{
470470
NestedCallsTracker tracker;

sycl/source/detail/queue_impl.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -649,9 +649,6 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
649649
static ContextImplPtr getContext(queue_impl *Queue) {
650650
return Queue ? Queue->getContextImplPtr() : nullptr;
651651
}
652-
static ContextImplPtr getContext(const QueueImplPtr &Queue) {
653-
return getContext(Queue.get());
654-
}
655652

656653
// Must be called under MMutex protection
657654
void doUnenqueuedCommandCleanup(
@@ -688,7 +685,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
688685
protected:
689686
template <typename HandlerType = handler>
690687
EventImplPtr insertHelperBarrier(const HandlerType &Handler) {
691-
auto &Queue = Handler.impl->get_queue();
688+
queue_impl &Queue = Handler.impl->get_queue();
692689
auto ResEvent = detail::event_impl::create_device_event(Queue);
693690
ur_event_handle_t UREvent = nullptr;
694691
getAdapter()->call<UrApiKind::urEnqueueEventsWaitWithBarrier>(

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,6 @@ static bool isOnSameContext(const ContextImplPtr Context, queue_impl *Queue) {
5757
// contexts comparison.
5858
return Context == queue_impl::getContext(Queue);
5959
}
60-
static bool isOnSameContext(const ContextImplPtr Context,
61-
const QueueImplPtr &Queue) {
62-
return isOnSameContext(Context, Queue.get());
63-
}
6460

6561
/// Checks if the required access mode is allowed under the current one.
6662
static bool isAccessModeAllowed(access::mode Required, access::mode Current) {
@@ -183,7 +179,7 @@ MemObjRecord *Scheduler::GraphBuilder::getMemObjRecord(SYCLMemObjI *MemObject) {
183179
}
184180

185181
MemObjRecord *
186-
Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
182+
Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
187183
const Requirement *Req) {
188184
SYCLMemObjI *MemObject = Req->MSYCLMemObj;
189185
MemObjRecord *Record = getMemObjRecord(MemObject);
@@ -231,8 +227,8 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
231227
MemObject->MRecord.reset(
232228
new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency});
233229
std::vector<Command *> ToEnqueue;
234-
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr,
235-
ToEnqueue);
230+
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req,
231+
InteropQueuePtr.get(), ToEnqueue);
236232
assert(ToEnqueue.empty() && "Creation of the first alloca for a record "
237233
"shouldn't lead to any enqueuing (no linked "
238234
"alloca or exceeding the leaf limit).");
@@ -274,14 +270,13 @@ void Scheduler::GraphBuilder::addNodeToLeaves(
274270
}
275271

276272
UpdateHostRequirementCommand *Scheduler::GraphBuilder::insertUpdateHostReqCmd(
277-
MemObjRecord *Record, Requirement *Req, const QueueImplPtr &Queue,
273+
MemObjRecord *Record, Requirement *Req, queue_impl *Queue,
278274
std::vector<Command *> &ToEnqueue) {
279275
auto Context = queue_impl::getContext(Queue);
280276
AllocaCommandBase *AllocaCmd = findAllocaForReq(Record, Req, Context);
281277
assert(AllocaCmd && "There must be alloca for requirement!");
282278
UpdateHostRequirementCommand *UpdateCommand =
283-
new UpdateHostRequirementCommand(Queue.get(), *Req, AllocaCmd,
284-
&Req->MData);
279+
new UpdateHostRequirementCommand(Queue, *Req, AllocaCmd, &Req->MData);
285280
// Need copy of requirement because after host accessor destructor call
286281
// dependencies become invalid if requirement is stored by pointer.
287282
const Requirement *StoredReq = UpdateCommand->getRequirement();
@@ -330,9 +325,10 @@ static Command *insertMapUnmapForLinkedCmds(AllocaCommandBase *AllocaCmdSrc,
330325
return MapCmd;
331326
}
332327

333-
Command *Scheduler::GraphBuilder::insertMemoryMove(
334-
MemObjRecord *Record, Requirement *Req, const QueueImplPtr &Queue,
335-
std::vector<Command *> &ToEnqueue) {
328+
Command *
329+
Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
330+
Requirement *Req, queue_impl *Queue,
331+
std::vector<Command *> &ToEnqueue) {
336332
AllocaCommandBase *AllocaCmdDst =
337333
getOrCreateAllocaForReq(Record, Req, Queue, ToEnqueue);
338334
if (!AllocaCmdDst)
@@ -519,7 +515,7 @@ Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
519515
auto SYCLMemObj = static_cast<detail::SYCLMemObjT *>(Req->MSYCLMemObj);
520516
SYCLMemObj->handleWriteAccessorCreation();
521517
}
522-
// Host accessor is not attached to any queue so no QueueImplPtr object to be
518+
// Host accessor is not attached to any queue so no queue object to be
523519
// sent to getOrInsertMemObjRecord.
524520
MemObjRecord *Record = getOrInsertMemObjRecord(nullptr, Req);
525521
if (MPrintOptionsArray[BeforeAddHostAcc])
@@ -691,7 +687,7 @@ static bool checkHostUnifiedMemory(const ContextImplPtr &Ctx) {
691687
// Note, creation of new allocation command can lead to the current context
692688
// (Record->MCurContext) change.
693689
AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
694-
MemObjRecord *Record, const Requirement *Req, const QueueImplPtr &Queue,
690+
MemObjRecord *Record, const Requirement *Req, queue_impl *Queue,
695691
std::vector<Command *> &ToEnqueue) {
696692
auto Context = queue_impl::getContext(Queue);
697693
AllocaCommandBase *AllocaCmd =
@@ -710,8 +706,8 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
710706

711707
auto *ParentAlloca =
712708
getOrCreateAllocaForReq(Record, &ParentRequirement, Queue, ToEnqueue);
713-
AllocaCmd = new AllocaSubBufCommand(Queue.get(), *Req, ParentAlloca,
714-
ToEnqueue, ToCleanUp);
709+
AllocaCmd = new AllocaSubBufCommand(Queue, *Req, ParentAlloca, ToEnqueue,
710+
ToCleanUp);
715711
} else {
716712

717713
const Requirement FullReq(/*Offset*/ {0, 0, 0}, Req->MMemoryRange,
@@ -787,8 +783,8 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
787783
}
788784
}
789785

790-
AllocaCmd = new AllocaCommand(Queue.get(), FullReq, InitFromUserData,
791-
LinkedAllocaCmd);
786+
AllocaCmd =
787+
new AllocaCommand(Queue, FullReq, InitFromUserData, LinkedAllocaCmd);
792788

793789
// Update linked command
794790
if (LinkedAllocaCmd) {
@@ -926,16 +922,16 @@ static void combineAccessModesOfReqs(std::vector<Requirement *> &Reqs) {
926922
}
927923

928924
Command *Scheduler::GraphBuilder::addCG(
929-
std::unique_ptr<detail::CG> CommandGroup, const QueueImplPtr &Queue,
925+
std::unique_ptr<detail::CG> CommandGroup, queue_impl *Queue,
930926
std::vector<Command *> &ToEnqueue, bool EventNeeded,
931927
ur_exp_command_buffer_handle_t CommandBuffer,
932928
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies) {
933929
std::vector<Requirement *> &Reqs = CommandGroup->getRequirements();
934930
std::vector<detail::EventImplPtr> &Events = CommandGroup->getEvents();
935931

936-
auto NewCmd = std::make_unique<ExecCGCommand>(
937-
std::move(CommandGroup), Queue.get(), EventNeeded, CommandBuffer,
938-
std::move(Dependencies));
932+
auto NewCmd = std::make_unique<ExecCGCommand>(std::move(CommandGroup), Queue,
933+
EventNeeded, CommandBuffer,
934+
std::move(Dependencies));
939935

940936
if (!NewCmd)
941937
throw exception(make_error_code(errc::memory_allocation),
@@ -958,9 +954,9 @@ Command *Scheduler::GraphBuilder::addCG(
958954
bool isSameCtx = false;
959955

960956
{
961-
const QueueImplPtr &QueueForAlloca =
957+
queue_impl *QueueForAlloca =
962958
isInteropTask
963-
? static_cast<detail::CGHostTask &>(NewCmd->getCG()).MQueue
959+
? static_cast<detail::CGHostTask &>(NewCmd->getCG()).MQueue.get()
964960
: Queue;
965961

966962
Record = getOrInsertMemObjRecord(QueueForAlloca, Req);
@@ -990,15 +986,15 @@ Command *Scheduler::GraphBuilder::addCG(
990986
// Cannot directly copy memory from OpenCL device to OpenCL device -
991987
// create two copies: device->host and host->device.
992988
bool NeedMemMoveToHost = false;
993-
auto MemMoveTargetQueue = Queue;
989+
queue_impl *MemMoveTargetQueue = Queue;
994990

995991
if (isInteropTask) {
996992
const detail::CGHostTask &HT =
997993
static_cast<detail::CGHostTask &>(NewCmd->getCG());
998994

999-
if (!isOnSameContext(Record->MCurContext, HT.MQueue)) {
995+
if (!isOnSameContext(Record->MCurContext, HT.MQueue.get())) {
1000996
NeedMemMoveToHost = true;
1001-
MemMoveTargetQueue = HT.MQueue;
997+
MemMoveTargetQueue = HT.MQueue.get();
1002998
}
1003999
} else if (Queue && Record->MCurContext)
10041000
NeedMemMoveToHost = true;
@@ -1230,7 +1226,9 @@ Command *Scheduler::GraphBuilder::connectDepEvent(
12301226
try {
12311227
std::shared_ptr<detail::HostTask> HT(new detail::HostTask);
12321228
std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
1233-
std::move(HT), /* Queue = */ Cmd->getQueue(), /* Context = */ {},
1229+
std::move(HT),
1230+
/* Queue = */ Cmd->getQueue(),
1231+
/* Context = */ {},
12341232
/* Args = */ {},
12351233
detail::CG::StorageInitHelper(
12361234
/* ArgsStorage = */ {}, /* AccStorage = */ {},
@@ -1281,11 +1279,11 @@ Command *Scheduler::GraphBuilder::addCommandGraphUpdate(
12811279
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
12821280
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
12831281
Nodes,
1284-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
1282+
queue_impl *Queue, std::vector<Requirement *> Requirements,
12851283
std::vector<detail::EventImplPtr> &Events,
12861284
std::vector<Command *> &ToEnqueue) {
12871285
auto NewCmd =
1288-
std::make_unique<UpdateCommandBufferCommand>(Queue.get(), Graph, Nodes);
1286+
std::make_unique<UpdateCommandBufferCommand>(Queue, Graph, Nodes);
12891287
// If there are multiple requirements for the same memory object, its
12901288
// AllocaCommand creation will be dependent on the access mode of the first
12911289
// requirement. Combine these access modes to take all of them into account.

sycl/source/detail/scheduler/scheduler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void Scheduler::waitForRecordToFinish(MemObjRecord *Record,
103103
}
104104

105105
EventImplPtr Scheduler::addCG(
106-
std::unique_ptr<detail::CG> CommandGroup, const QueueImplPtr &Queue,
106+
std::unique_ptr<detail::CG> CommandGroup, queue_impl &Queue,
107107
bool EventNeeded, ur_exp_command_buffer_handle_t CommandBuffer,
108108
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies) {
109109
EventImplPtr NewEvent = nullptr;
@@ -128,7 +128,7 @@ EventImplPtr Scheduler::addCG(
128128
break;
129129
}
130130
default:
131-
NewCmd = MGraphBuilder.addCG(std::move(CommandGroup), std::move(Queue),
131+
NewCmd = MGraphBuilder.addCG(std::move(CommandGroup), &Queue,
132132
AuxiliaryCmds, EventNeeded, CommandBuffer,
133133
std::move(Dependencies));
134134
}
@@ -646,7 +646,7 @@ EventImplPtr Scheduler::addCommandGraphUpdate(
646646
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
647647
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
648648
Nodes,
649-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
649+
queue_impl *Queue, std::vector<Requirement *> Requirements,
650650
std::vector<detail::EventImplPtr> &Events) {
651651
std::vector<Command *> AuxiliaryCmds;
652652
EventImplPtr NewCmdEvent = nullptr;

sycl/source/detail/scheduler/scheduler.hpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ class DispatchHostTask;
187187

188188
using ContextImplPtr = std::shared_ptr<detail::context_impl>;
189189
using EventImplPtr = std::shared_ptr<detail::event_impl>;
190-
using QueueImplPtr = std::shared_ptr<detail::queue_impl>;
191190
using StreamImplPtr = std::shared_ptr<detail::stream_impl>;
192191

193192
using CommandPtr = std::unique_ptr<Command>;
@@ -379,7 +378,7 @@ class Scheduler {
379378
/// \return an event object to wait on for command group completion. It can
380379
/// be a discarded event.
381380
EventImplPtr addCG(
382-
std::unique_ptr<detail::CG> CommandGroup, const QueueImplPtr &Queue,
381+
std::unique_ptr<detail::CG> CommandGroup, queue_impl &Queue,
383382
bool EventNeeded, ur_exp_command_buffer_handle_t CommandBuffer = nullptr,
384383
const std::vector<ur_exp_command_buffer_sync_point_t> &Dependencies = {});
385384

@@ -477,7 +476,7 @@ class Scheduler {
477476
ext::oneapi::experimental::detail::exec_graph_impl *Graph,
478477
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
479478
Nodes,
480-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
479+
queue_impl *Queue, std::vector<Requirement *> Requirements,
481480
std::vector<detail::EventImplPtr> &Events);
482481

483482
static bool CheckEventReadiness(context_impl &Context,
@@ -560,9 +559,8 @@ class Scheduler {
560559
/// \return a command that represents command group execution and a bool
561560
/// indicating whether this command should be enqueued to the graph
562561
/// processor right away or not.
563-
Command *addCG(std::unique_ptr<detail::CG> CommandGroup,
564-
const QueueImplPtr &Queue, std::vector<Command *> &ToEnqueue,
565-
bool EventNeeded,
562+
Command *addCG(std::unique_ptr<detail::CG> CommandGroup, queue_impl *Queue,
563+
std::vector<Command *> &ToEnqueue, bool EventNeeded,
566564
ur_exp_command_buffer_handle_t CommandBuffer = nullptr,
567565
const std::vector<ur_exp_command_buffer_sync_point_t>
568566
&Dependencies = {});
@@ -600,15 +598,15 @@ class Scheduler {
600598
/// used when the user provides a "secondary" queue to the submit method
601599
/// which may be used when the command fails to enqueue/execute in the
602600
/// primary queue.
603-
void rescheduleCommand(Command *Cmd, const QueueImplPtr &Queue);
601+
void rescheduleCommand(Command *Cmd, queue_impl *Queue);
604602

605603
/// \return a pointer to the corresponding memory object record for the
606604
/// SYCL memory object provided, or nullptr if it does not exist.
607605
MemObjRecord *getMemObjRecord(SYCLMemObjI *MemObject);
608606

609607
/// \return a pointer to MemObjRecord for pointer to memory object. If the
610608
/// record is not found, nullptr is returned.
611-
MemObjRecord *getOrInsertMemObjRecord(const QueueImplPtr &Queue,
609+
MemObjRecord *getOrInsertMemObjRecord(queue_impl *Queue,
612610
const Requirement *Req);
613611

614612
/// Decrements leaf counters for all leaves of the record.
@@ -656,7 +654,7 @@ class Scheduler {
656654
std::vector<
657655
std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
658656
Nodes,
659-
const QueueImplPtr &Queue, std::vector<Requirement *> Requirements,
657+
queue_impl *Queue, std::vector<Requirement *> Requirements,
660658
std::vector<detail::EventImplPtr> &Events,
661659
std::vector<Command *> &ToEnqueue);
662660

@@ -673,7 +671,7 @@ class Scheduler {
673671
/// \param Req is a Requirement describing destination.
674672
/// \param Queue is a queue that is bound to target context.
675673
Command *insertMemoryMove(MemObjRecord *Record, Requirement *Req,
676-
const QueueImplPtr &Queue,
674+
queue_impl *Queue,
677675
std::vector<Command *> &ToEnqueue);
678676

679677
// Inserts commands required to remap the memory object to its current host
@@ -684,7 +682,7 @@ class Scheduler {
684682

685683
UpdateHostRequirementCommand *
686684
insertUpdateHostReqCmd(MemObjRecord *Record, Requirement *Req,
687-
const QueueImplPtr &Queue,
685+
queue_impl *Queue,
688686
std::vector<Command *> &ToEnqueue);
689687

690688
/// Finds dependencies for the requirement.
@@ -717,7 +715,7 @@ class Scheduler {
717715
/// If none found, creates new one.
718716
AllocaCommandBase *
719717
getOrCreateAllocaForReq(MemObjRecord *Record, const Requirement *Req,
720-
const QueueImplPtr &Queue,
718+
queue_impl *Queue,
721719
std::vector<Command *> &ToEnqueue);
722720

723721
void markModifiedIfWrite(MemObjRecord *Record, Requirement *Req);

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ event handler::finalize() {
936936
CommandGroup->getRequirements().size() == 0;
937937

938938
detail::EventImplPtr Event = detail::Scheduler::getInstance().addCG(
939-
std::move(CommandGroup), Queue->shared_from_this(), !DiscardEvent);
939+
std::move(CommandGroup), *Queue, !DiscardEvent);
940940

941941
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
942942
MLastEvent = DiscardEvent ? nullptr : Event;

0 commit comments

Comments
 (0)