Skip to content

[SYCL][NFC] Refactor Command::get* member functions #2949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/include/CL/sycl/exception.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class __SYCL_EXPORT exception : public std::exception {
public:
exception() = default;

const char *what() const noexcept final override;
const char *what() const noexcept final;

bool has_context() const;

Expand Down
26 changes: 12 additions & 14 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,8 @@ void Command::makeTraceEventEpilog() {
}

void Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep) {
const ContextImplPtr &Context = getContext();
const QueueImplPtr &WorkerQueue = getWorkerQueue();
const ContextImplPtr &WorkerContext = WorkerQueue->getContextImplPtr();

// 1. Async work is not supported for host device.
// 2. The event handle can be null in case of, for example, alloca command,
Expand All @@ -495,25 +496,24 @@ void Command::processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep) {
}

// Do not add redundant event dependencies for in-order queues.
const QueueImplPtr &WorkerQueue = getWorkerQueue();
if (Dep.MDepCommand && Dep.MDepCommand->getWorkerQueue() == WorkerQueue &&
WorkerQueue->has_property<property::queue::in_order>())
return;

ContextImplPtr DepEventContext = DepEvent->getContextImpl();
// If contexts don't match we'll connect them using host task
if (DepEventContext != Context && !Context->is_host()) {
if (DepEventContext != WorkerContext && !WorkerContext->is_host()) {
Scheduler::GraphBuilder &GB = Scheduler::getInstance().MGraphBuilder;
GB.connectDepEvent(this, DepEvent, Dep);
} else
MPreparedDepsEvents.push_back(std::move(DepEvent));
}

ContextImplPtr Command::getContext() const {
return detail::getSyclObjImpl(MQueue->get_context());
const ContextImplPtr &Command::getWorkerContext() const {
return MQueue->getContextImplPtr();
}

QueueImplPtr Command::getWorkerQueue() const { return MQueue; }
const QueueImplPtr &Command::getWorkerQueue() const { return MQueue; }

void Command::addDep(DepDesc NewDep) {
if (NewDep.MDepCommand) {
Expand Down Expand Up @@ -1135,12 +1135,11 @@ void MemCpyCommand::emitInstrumentationData() {
#endif
}

ContextImplPtr MemCpyCommand::getContext() const {
const QueueImplPtr &Queue = getWorkerQueue();
return detail::getSyclObjImpl(Queue->get_context());
const ContextImplPtr &MemCpyCommand::getWorkerContext() const {
return getWorkerQueue()->getContextImplPtr();
}

QueueImplPtr MemCpyCommand::getWorkerQueue() const {
const QueueImplPtr &MemCpyCommand::getWorkerQueue() const {
return MQueue->is_host() ? MSrcQueue : MQueue;
}

Expand Down Expand Up @@ -1276,12 +1275,11 @@ void MemCpyCommandHost::emitInstrumentationData() {
#endif
}

ContextImplPtr MemCpyCommandHost::getContext() const {
const QueueImplPtr &Queue = getWorkerQueue();
return detail::getSyclObjImpl(Queue->get_context());
const ContextImplPtr &MemCpyCommandHost::getWorkerContext() const {
return getWorkerQueue()->getContextImplPtr();
}

QueueImplPtr MemCpyCommandHost::getWorkerQueue() const {
const QueueImplPtr &MemCpyCommandHost::getWorkerQueue() const {
return MQueue->is_host() ? MSrcQueue : MQueue;
}

Expand Down
88 changes: 45 additions & 43 deletions sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ class Command {
return MEnqueueStatus == EnqueueResultT::SyclEnqueueBlocked;
}

std::shared_ptr<queue_impl> getQueue() const { return MQueue; }
const QueueImplPtr &getQueue() const { return MQueue; }

std::shared_ptr<event_impl> getEvent() const { return MEvent; }
const EventImplPtr &getEvent() const { return MEvent; }

// Methods needed to support SYCL instrumentation

Expand Down Expand Up @@ -179,11 +179,13 @@ class Command {

const char *getBlockReason() const;

virtual ContextImplPtr getContext() const;
/// Get the context of the queue this command will be submitted to. Could
/// differ from the context of MQueue for memory copy commands.
virtual const ContextImplPtr &getWorkerContext() const;

/// Get the queue this command will be submitted to. Could differ from MQueue
/// for memory copy commands.
virtual QueueImplPtr getWorkerQueue() const;
virtual const QueueImplPtr &getWorkerQueue() const;

protected:
EventImplPtr MEvent;
Expand All @@ -205,7 +207,7 @@ class Command {
///
/// Glueing (i.e. connecting) will be performed if and only if DepEvent is
/// not from host context and its context doesn't match to context of this
/// command. Context of this command is fetched via getContext().
/// command. Context of this command is fetched via getWorkerContext().
///
/// Optionality of Dep is set by Dep.MDepCommand not equal to nullptr.
void processDepEvent(EventImplPtr DepEvent, const DepDesc &Dep);
Expand All @@ -221,7 +223,7 @@ class Command {
friend class DispatchHostTask;

public:
const std::vector<EventImplPtr> getPreparedHostDepsEvents() const {
const std::vector<EventImplPtr> &getPreparedHostDepsEvents() const {
return MPreparedHostDepsEvents;
}

Expand Down Expand Up @@ -293,15 +295,15 @@ class EmptyCommand : public Command {
public:
EmptyCommand(QueueImplPtr Queue);

void printDot(std::ostream &Stream) const final override;
const Requirement *getRequirement() const final override { return &MRequirements[0]; }
void printDot(std::ostream &Stream) const final;
const Requirement *getRequirement() const final { return &MRequirements[0]; }
void addRequirement(Command *DepCmd, AllocaCommandBase *AllocaCmd,
const Requirement *Req);

void emitInstrumentationData() override;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

// Employing deque here as it allows to push_back/emplace_back without
// invalidation of pointer or reference to stored data item regardless of
Expand All @@ -315,11 +317,11 @@ class ReleaseCommand : public Command {
public:
ReleaseCommand(QueueImplPtr Queue, AllocaCommandBase *AllocaCmd);

void printDot(std::ostream &Stream) const final override;
void printDot(std::ostream &Stream) const final;
void emitInstrumentationData() override;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

/// Command which allocates memory release command should dealocate.
AllocaCommandBase *MAllocaCmd = nullptr;
Expand All @@ -337,7 +339,7 @@ class AllocaCommandBase : public Command {

virtual void *getMemAllocation() const = 0;

const Requirement *getRequirement() const final override { return &MRequirement; }
const Requirement *getRequirement() const final { return &MRequirement; }

void emitInstrumentationData() override;

Expand Down Expand Up @@ -369,12 +371,12 @@ class AllocaCommand : public AllocaCommandBase {
bool InitFromUserData = true,
AllocaCommandBase *LinkedAllocaCmd = nullptr);

void *getMemAllocation() const final override { return MMemAllocation; }
void printDot(std::ostream &Stream) const final override;
void *getMemAllocation() const final { return MMemAllocation; }
void printDot(std::ostream &Stream) const final;
void emitInstrumentationData() override;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

/// The flag indicates that alloca should try to reuse pointer provided by
/// the user during memory object construction.
Expand All @@ -387,13 +389,13 @@ class AllocaSubBufCommand : public AllocaCommandBase {
AllocaSubBufCommand(QueueImplPtr Queue, Requirement Req,
AllocaCommandBase *ParentAlloca);

void *getMemAllocation() const final override;
void printDot(std::ostream &Stream) const final override;
void *getMemAllocation() const final;
void printDot(std::ostream &Stream) const final;
AllocaCommandBase *getParentAlloca() { return MParentAlloca; }
void emitInstrumentationData() override;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

AllocaCommandBase *MParentAlloca = nullptr;
};
Expand All @@ -404,12 +406,12 @@ class MapMemObject : public Command {
MapMemObject(AllocaCommandBase *SrcAllocaCmd, Requirement Req, void **DstPtr,
QueueImplPtr Queue, access::mode MapMode);

void printDot(std::ostream &Stream) const final override;
const Requirement *getRequirement() const final override { return &MSrcReq; }
void printDot(std::ostream &Stream) const final;
const Requirement *getRequirement() const final { return &MSrcReq; }
void emitInstrumentationData() override;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

AllocaCommandBase *MSrcAllocaCmd = nullptr;
Requirement MSrcReq;
Expand All @@ -423,12 +425,12 @@ class UnMapMemObject : public Command {
UnMapMemObject(AllocaCommandBase *DstAllocaCmd, Requirement Req,
void **SrcPtr, QueueImplPtr Queue);

void printDot(std::ostream &Stream) const final override;
const Requirement *getRequirement() const final override { return &MDstReq; }
void printDot(std::ostream &Stream) const final;
const Requirement *getRequirement() const final { return &MDstReq; }
void emitInstrumentationData() override;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

AllocaCommandBase *MDstAllocaCmd = nullptr;
Requirement MDstReq;
Expand All @@ -443,14 +445,14 @@ class MemCpyCommand : public Command {
Requirement DstReq, AllocaCommandBase *DstAllocaCmd,
QueueImplPtr SrcQueue, QueueImplPtr DstQueue);

void printDot(std::ostream &Stream) const final override;
const Requirement *getRequirement() const final override { return &MDstReq; }
void emitInstrumentationData() final override;
ContextImplPtr getContext() const final override;
QueueImplPtr getWorkerQueue() const final override;
void printDot(std::ostream &Stream) const final;
const Requirement *getRequirement() const final { return &MDstReq; }
void emitInstrumentationData() final;
const ContextImplPtr &getWorkerContext() const final;
const QueueImplPtr &getWorkerQueue() const final;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

QueueImplPtr MSrcQueue;
Requirement MSrcReq;
Expand All @@ -467,14 +469,14 @@ class MemCpyCommandHost : public Command {
Requirement DstReq, void **DstPtr, QueueImplPtr SrcQueue,
QueueImplPtr DstQueue);

void printDot(std::ostream &Stream) const final override;
const Requirement *getRequirement() const final override { return &MDstReq; }
void emitInstrumentationData() final override;
ContextImplPtr getContext() const final override;
QueueImplPtr getWorkerQueue() const final override;
void printDot(std::ostream &Stream) const final;
const Requirement *getRequirement() const final { return &MDstReq; }
void emitInstrumentationData() final;
const ContextImplPtr &getWorkerContext() const final;
const QueueImplPtr &getWorkerQueue() const final;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

QueueImplPtr MSrcQueue;
Requirement MSrcReq;
Expand All @@ -493,8 +495,8 @@ class ExecCGCommand : public Command {

void clearStreams();

void printDot(std::ostream &Stream) const final override;
void emitInstrumentationData() final override;
void printDot(std::ostream &Stream) const final;
void emitInstrumentationData() final;

detail::CG &getCG() const { return *MCommandGroup; }

Expand All @@ -512,7 +514,7 @@ class ExecCGCommand : public Command {
}

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

AllocaCommandBase *getAllocaForReq(Requirement *Req);

Expand All @@ -531,12 +533,12 @@ class UpdateHostRequirementCommand : public Command {
UpdateHostRequirementCommand(QueueImplPtr Queue, Requirement Req,
AllocaCommandBase *SrcAllocaCmd, void **DstPtr);

void printDot(std::ostream &Stream) const final override;
const Requirement *getRequirement() const final override { return &MDstReq; }
void emitInstrumentationData() final override;
void printDot(std::ostream &Stream) const final;
const Requirement *getRequirement() const final { return &MDstReq; }
void emitInstrumentationData() final;

private:
cl_int enqueueImp() final override;
cl_int enqueueImp() final;

AllocaCommandBase *MSrcAllocaCmd = nullptr;
Requirement MDstReq;
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ void Scheduler::GraphBuilder::removeRecordForMemObj(SYCLMemObjI *MemObject) {
void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
EventImplPtr DepEvent,
const DepDesc &Dep) {
assert(Cmd->getContext() != DepEvent->getContextImpl());
assert(Cmd->getWorkerContext() != DepEvent->getContextImpl());

// construct Host Task type command manually and make it depend on DepEvent
ExecCGCommand *ConnectCmd = nullptr;
Expand Down