Skip to content

Commit f621a20

Browse files
[SYCL][NFC] Refactor Command::get* member functions (#2949)
Including the following changes: - Return by const reference where possible - Remove redundant override keywords (applied globally) - Rename getContext -> getWorkerContext to mirror getWorkerQueue and better reflect potential discrepancy between the returned context and the context of MQueue.
1 parent 2249491 commit f621a20

File tree

4 files changed

+59
-59
lines changed

4 files changed

+59
-59
lines changed

sycl/include/CL/sycl/exception.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class __SYCL_EXPORT exception : public std::exception {
3030
public:
3131
exception() = default;
3232

33-
const char *what() const noexcept final override;
33+
const char *what() const noexcept final;
3434

3535
bool has_context() const;
3636

sycl/source/detail/scheduler/commands.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,8 @@ void Command::makeTraceEventEpilog() {
481481
}
482482

483483
void Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep) {
484-
const ContextImplPtr &Context = getContext();
484+
const QueueImplPtr &WorkerQueue = getWorkerQueue();
485+
const ContextImplPtr &WorkerContext = WorkerQueue->getContextImplPtr();
485486

486487
// 1. Async work is not supported for host device.
487488
// 2. The event handle can be null in case of, for example, alloca command,
@@ -495,25 +496,24 @@ void Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep) {
495496
}
496497

497498
// Do not add redundant event dependencies for in-order queues.
498-
const QueueImplPtr &WorkerQueue = getWorkerQueue();
499499
if (Dep.MDepCommand && Dep.MDepCommand->getWorkerQueue() == WorkerQueue &&
500500
WorkerQueue->has_property<property::queue::in_order>())
501501
return;
502502

503503
ContextImplPtr DepEventContext = DepEvent->getContextImpl();
504504
// If contexts don't match we'll connect them using host task
505-
if (DepEventContext != Context && !Context->is_host()) {
505+
if (DepEventContext != WorkerContext && !WorkerContext->is_host()) {
506506
Scheduler::GraphBuilder &GB = Scheduler::getInstance().MGraphBuilder;
507507
GB.connectDepEvent(this, DepEvent, Dep);
508508
} else
509509
MPreparedDepsEvents.push_back(std::move(DepEvent));
510510
}
511511

512-
ContextImplPtr Command::getContext() const {
513-
return detail::getSyclObjImpl(MQueue->get_context());
512+
const ContextImplPtr &Command::getWorkerContext() const {
513+
return MQueue->getContextImplPtr();
514514
}
515515

516-
QueueImplPtr Command::getWorkerQueue() const { return MQueue; }
516+
const QueueImplPtr &Command::getWorkerQueue() const { return MQueue; }
517517

518518
void Command::addDep(DepDesc NewDep) {
519519
if (NewDep.MDepCommand) {
@@ -1135,12 +1135,11 @@ void MemCpyCommand::emitInstrumentationData() {
11351135
#endif
11361136
}
11371137

1138-
ContextImplPtr MemCpyCommand::getContext() const {
1139-
const QueueImplPtr &Queue = getWorkerQueue();
1140-
return detail::getSyclObjImpl(Queue->get_context());
1138+
const ContextImplPtr &MemCpyCommand::getWorkerContext() const {
1139+
return getWorkerQueue()->getContextImplPtr();
11411140
}
11421141

1143-
QueueImplPtr MemCpyCommand::getWorkerQueue() const {
1142+
const QueueImplPtr &MemCpyCommand::getWorkerQueue() const {
11441143
return MQueue->is_host() ? MSrcQueue : MQueue;
11451144
}
11461145

@@ -1276,12 +1275,11 @@ void MemCpyCommandHost::emitInstrumentationData() {
12761275
#endif
12771276
}
12781277

1279-
ContextImplPtr MemCpyCommandHost::getContext() const {
1280-
const QueueImplPtr &Queue = getWorkerQueue();
1281-
return detail::getSyclObjImpl(Queue->get_context());
1278+
const ContextImplPtr &MemCpyCommandHost::getWorkerContext() const {
1279+
return getWorkerQueue()->getContextImplPtr();
12821280
}
12831281

1284-
QueueImplPtr MemCpyCommandHost::getWorkerQueue() const {
1282+
const QueueImplPtr &MemCpyCommandHost::getWorkerQueue() const {
12851283
return MQueue->is_host() ? MSrcQueue : MQueue;
12861284
}
12871285

sycl/source/detail/scheduler/commands.hpp

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ class Command {
134134
return MEnqueueStatus == EnqueueResultT::SyclEnqueueBlocked;
135135
}
136136

137-
std::shared_ptr<queue_impl> getQueue() const { return MQueue; }
137+
const QueueImplPtr &getQueue() const { return MQueue; }
138138

139-
std::shared_ptr<event_impl> getEvent() const { return MEvent; }
139+
const EventImplPtr &getEvent() const { return MEvent; }
140140

141141
// Methods needed to support SYCL instrumentation
142142

@@ -179,11 +179,13 @@ class Command {
179179

180180
const char *getBlockReason() const;
181181

182-
virtual ContextImplPtr getContext() const;
182+
/// Get the context of the queue this command will be submitted to. Could
183+
/// differ from the context of MQueue for memory copy commands.
184+
virtual const ContextImplPtr &getWorkerContext() const;
183185

184186
/// Get the queue this command will be submitted to. Could differ from MQueue
185187
/// for memory copy commands.
186-
virtual QueueImplPtr getWorkerQueue() const;
188+
virtual const QueueImplPtr &getWorkerQueue() const;
187189

188190
protected:
189191
EventImplPtr MEvent;
@@ -205,7 +207,7 @@ class Command {
205207
///
206208
/// Glueing (i.e. connecting) will be performed if and only if DepEvent is
207209
/// not from host context and its context doesn't match to context of this
208-
/// command. Context of this command is fetched via getContext().
210+
/// command. Context of this command is fetched via getWorkerContext().
209211
///
210212
/// Optionality of Dep is set by Dep.MDepCommand not equal to nullptr.
211213
void processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep);
@@ -221,7 +223,7 @@ class Command {
221223
friend class DispatchHostTask;
222224

223225
public:
224-
const std::vector<EventImplPtr> getPreparedHostDepsEvents() const {
226+
const std::vector<EventImplPtr> &getPreparedHostDepsEvents() const {
225227
return MPreparedHostDepsEvents;
226228
}
227229

@@ -293,15 +295,15 @@ class EmptyCommand : public Command {
293295
public:
294296
EmptyCommand(QueueImplPtr Queue);
295297

296-
void printDot(std::ostream &Stream) const final override;
297-
const Requirement *getRequirement() const final override { return &MRequirements[0]; }
298+
void printDot(std::ostream &Stream) const final;
299+
const Requirement *getRequirement() const final { return &MRequirements[0]; }
298300
void addRequirement(Command *DepCmd, AllocaCommandBase *AllocaCmd,
299301
const Requirement *Req);
300302

301303
void emitInstrumentationData() override;
302304

303305
private:
304-
cl_int enqueueImp() final override;
306+
cl_int enqueueImp() final;
305307

306308
// Employing deque here as it allows to push_back/emplace_back without
307309
// invalidation of pointer or reference to stored data item regardless of
@@ -315,11 +317,11 @@ class ReleaseCommand : public Command {
315317
public:
316318
ReleaseCommand(QueueImplPtr Queue, AllocaCommandBase *AllocaCmd);
317319

318-
void printDot(std::ostream &Stream) const final override;
320+
void printDot(std::ostream &Stream) const final;
319321
void emitInstrumentationData() override;
320322

321323
private:
322-
cl_int enqueueImp() final override;
324+
cl_int enqueueImp() final;
323325

324326
/// Command which allocates memory release command should dealocate.
325327
AllocaCommandBase *MAllocaCmd = nullptr;
@@ -337,7 +339,7 @@ class AllocaCommandBase : public Command {
337339

338340
virtual void *getMemAllocation() const = 0;
339341

340-
const Requirement *getRequirement() const final override { return &MRequirement; }
342+
const Requirement *getRequirement() const final { return &MRequirement; }
341343

342344
void emitInstrumentationData() override;
343345

@@ -369,12 +371,12 @@ class AllocaCommand : public AllocaCommandBase {
369371
bool InitFromUserData = true,
370372
AllocaCommandBase *LinkedAllocaCmd = nullptr);
371373

372-
void *getMemAllocation() const final override { return MMemAllocation; }
373-
void printDot(std::ostream &Stream) const final override;
374+
void *getMemAllocation() const final { return MMemAllocation; }
375+
void printDot(std::ostream &Stream) const final;
374376
void emitInstrumentationData() override;
375377

376378
private:
377-
cl_int enqueueImp() final override;
379+
cl_int enqueueImp() final;
378380

379381
/// The flag indicates that alloca should try to reuse pointer provided by
380382
/// the user during memory object construction.
@@ -387,13 +389,13 @@ class AllocaSubBufCommand : public AllocaCommandBase {
387389
AllocaSubBufCommand(QueueImplPtr Queue, Requirement Req,
388390
AllocaCommandBase *ParentAlloca);
389391

390-
void *getMemAllocation() const final override;
391-
void printDot(std::ostream &Stream) const final override;
392+
void *getMemAllocation() const final;
393+
void printDot(std::ostream &Stream) const final;
392394
AllocaCommandBase *getParentAlloca() { return MParentAlloca; }
393395
void emitInstrumentationData() override;
394396

395397
private:
396-
cl_int enqueueImp() final override;
398+
cl_int enqueueImp() final;
397399

398400
AllocaCommandBase *MParentAlloca = nullptr;
399401
};
@@ -404,12 +406,12 @@ class MapMemObject : public Command {
404406
MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req, void **DstPtr,
405407
QueueImplPtr Queue, access::mode MapMode);
406408

407-
void printDot(std::ostream &Stream) const final override;
408-
const Requirement *getRequirement() const final override { return &MSrcReq; }
409+
void printDot(std::ostream &Stream) const final;
410+
const Requirement *getRequirement() const final { return &MSrcReq; }
409411
void emitInstrumentationData() override;
410412

411413
private:
412-
cl_int enqueueImp() final override;
414+
cl_int enqueueImp() final;
413415

414416
AllocaCommandBase *MSrcAllocaCmd = nullptr;
415417
Requirement MSrcReq;
@@ -423,12 +425,12 @@ class UnMapMemObject : public Command {
423425
UnMapMemObject(AllocaCommandBase *DstAllocaCmd, Requirement Req,
424426
void **SrcPtr, QueueImplPtr Queue);
425427

426-
void printDot(std::ostream &Stream) const final override;
427-
const Requirement *getRequirement() const final override { return &MDstReq; }
428+
void printDot(std::ostream &Stream) const final;
429+
const Requirement *getRequirement() const final { return &MDstReq; }
428430
void emitInstrumentationData() override;
429431

430432
private:
431-
cl_int enqueueImp() final override;
433+
cl_int enqueueImp() final;
432434

433435
AllocaCommandBase *MDstAllocaCmd = nullptr;
434436
Requirement MDstReq;
@@ -443,14 +445,14 @@ class MemCpyCommand : public Command {
443445
Requirement DstReq, AllocaCommandBase *DstAllocaCmd,
444446
QueueImplPtr SrcQueue, QueueImplPtr DstQueue);
445447

446-
void printDot(std::ostream &Stream) const final override;
447-
const Requirement *getRequirement() const final override { return &MDstReq; }
448-
void emitInstrumentationData() final override;
449-
ContextImplPtr getContext() const final override;
450-
QueueImplPtr getWorkerQueue() const final override;
448+
void printDot(std::ostream &Stream) const final;
449+
const Requirement *getRequirement() const final { return &MDstReq; }
450+
void emitInstrumentationData() final;
451+
const ContextImplPtr &getWorkerContext() const final;
452+
const QueueImplPtr &getWorkerQueue() const final;
451453

452454
private:
453-
cl_int enqueueImp() final override;
455+
cl_int enqueueImp() final;
454456

455457
QueueImplPtr MSrcQueue;
456458
Requirement MSrcReq;
@@ -467,14 +469,14 @@ class MemCpyCommandHost : public Command {
467469
Requirement DstReq, void **DstPtr, QueueImplPtr SrcQueue,
468470
QueueImplPtr DstQueue);
469471

470-
void printDot(std::ostream &Stream) const final override;
471-
const Requirement *getRequirement() const final override { return &MDstReq; }
472-
void emitInstrumentationData() final override;
473-
ContextImplPtr getContext() const final override;
474-
QueueImplPtr getWorkerQueue() const final override;
472+
void printDot(std::ostream &Stream) const final;
473+
const Requirement *getRequirement() const final { return &MDstReq; }
474+
void emitInstrumentationData() final;
475+
const ContextImplPtr &getWorkerContext() const final;
476+
const QueueImplPtr &getWorkerQueue() const final;
475477

476478
private:
477-
cl_int enqueueImp() final override;
479+
cl_int enqueueImp() final;
478480

479481
QueueImplPtr MSrcQueue;
480482
Requirement MSrcReq;
@@ -493,8 +495,8 @@ class ExecCGCommand : public Command {
493495

494496
void clearStreams();
495497

496-
void printDot(std::ostream &Stream) const final override;
497-
void emitInstrumentationData() final override;
498+
void printDot(std::ostream &Stream) const final;
499+
void emitInstrumentationData() final;
498500

499501
detail::CG &getCG() const { return *MCommandGroup; }
500502

@@ -512,7 +514,7 @@ class ExecCGCommand : public Command {
512514
}
513515

514516
private:
515-
cl_int enqueueImp() final override;
517+
cl_int enqueueImp() final;
516518

517519
AllocaCommandBase *getAllocaForReq(Requirement *Req);
518520

@@ -531,12 +533,12 @@ class UpdateHostRequirementCommand : public Command {
531533
UpdateHostRequirementCommand(QueueImplPtr Queue, Requirement Req,
532534
AllocaCommandBase *SrcAllocaCmd, void **DstPtr);
533535

534-
void printDot(std::ostream &Stream) const final override;
535-
const Requirement *getRequirement() const final override { return &MDstReq; }
536-
void emitInstrumentationData() final override;
536+
void printDot(std::ostream &Stream) const final;
537+
const Requirement *getRequirement() const final { return &MDstReq; }
538+
void emitInstrumentationData() final;
537539

538540
private:
539-
cl_int enqueueImp() final override;
541+
cl_int enqueueImp() final;
540542

541543
AllocaCommandBase *MSrcAllocaCmd = nullptr;
542544
Requirement MDstReq;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {
10941094
void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
10951095
EventImplPtr DepEvent,
10961096
const DepDesc &Dep) {
1097-
assert(Cmd->getContext() != DepEvent->getContextImpl());
1097+
assert(Cmd->getWorkerContext() != DepEvent->getContextImpl());
10981098

10991099
// construct Host Task type command manually and make it depend on DepEvent
11001100
ExecCGCommand *ConnectCmd = nullptr;

0 commit comments

Comments
 (0)