Skip to content

Commit 6f3b4d7

Browse files
author
Sergey Kanaev
committed
[SYCL] Eliminate const_cast
Signed-off-by: Sergey Kanaev <[email protected]>
1 parent 692bf79 commit 6f3b4d7

File tree

2 files changed

+39
-34
lines changed

2 files changed

+39
-34
lines changed

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ MemObjRecord *Scheduler::GraphBuilder::getMemObjRecord(SYCLMemObjI *MemObject) {
129129

130130
MemObjRecord *
131131
Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
132-
Requirement *Req) {
132+
const Requirement *Req) {
133133
SYCLMemObjI *MemObject = Req->MSYCLMemObj;
134134
MemObjRecord *Record = getMemObjRecord(MemObject);
135135

@@ -416,8 +416,8 @@ Command *Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
416416
Command *UpdateHostAccCmd = insertUpdateHostReqCmd(Record, Req, HostQueue);
417417

418418
// Need empty command to be blocked until host accessor is destructed
419-
EmptyCommand *EmptyCmd = addEmptyCmd(UpdateHostAccCmd, {Req}, HostQueue,
420-
Command::BlockReason::HostAccessor);
419+
EmptyCommand *EmptyCmd = addEmptyCmd<Requirement>(
420+
UpdateHostAccCmd, {Req}, HostQueue, Command::BlockReason::HostAccessor);
421421

422422
Req->MBlockedCmd = EmptyCmd;
423423

@@ -446,7 +446,7 @@ Command *Scheduler::GraphBuilder::addCGUpdateHost(
446446
/// 2. New and examined commands has non-overlapping requirements -> can bypass
447447
/// 3. New and examined commands have different contexts -> cannot bypass
448448
std::set<Command *>
449-
Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record, Requirement *Req,
449+
Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record, const Requirement *Req,
450450
const ContextImplPtr &Context) {
451451
std::set<Command *> RetDeps;
452452
std::set<Command *> Visited;
@@ -514,7 +514,7 @@ DepDesc Scheduler::GraphBuilder::findDepForRecord(Command *Cmd,
514514
// The function searches for the alloca command matching context and
515515
// requirement.
516516
AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
517-
MemObjRecord *Record, Requirement *Req, const ContextImplPtr &Context) {
517+
MemObjRecord *Record, const Requirement *Req, const ContextImplPtr &Context) {
518518
auto IsSuitableAlloca = [&Context, Req](AllocaCommandBase *AllocaCmd) {
519519
bool Res = sameCtx(AllocaCmd->getQueue()->getContextImplPtr(), Context);
520520
if (IsSuitableSubReq(Req)) {
@@ -535,7 +535,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
535535
// Note, creation of new allocation command can lead to the current context
536536
// (Record->MCurContext) change.
537537
AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
538-
MemObjRecord *Record, Requirement *Req, QueueImplPtr Queue) {
538+
MemObjRecord *Record, const Requirement *Req, QueueImplPtr Queue) {
539539

540540
AllocaCommandBase *AllocaCmd =
541541
findAllocaForReq(Record, Req, Queue->getContextImplPtr());
@@ -640,9 +640,14 @@ void Scheduler::GraphBuilder::markModifiedIfWrite(MemObjRecord *Record,
640640
}
641641
}
642642

643-
EmptyCommand *Scheduler::GraphBuilder::addEmptyCmd(
644-
Command *Cmd, const std::vector<Requirement *> &Reqs,
645-
const QueueImplPtr &Queue, Command::BlockReason Reason) {
643+
template<typename T>
644+
typename std::enable_if<std::is_same<typename std::remove_cv<T>::type,
645+
Requirement>::value,
646+
EmptyCommand *>::type
647+
Scheduler::GraphBuilder::addEmptyCmd(Command *Cmd,
648+
const std::vector<T *> &Reqs,
649+
const QueueImplPtr &Queue,
650+
Command::BlockReason Reason) {
646651
EmptyCommand *EmptyCmd =
647652
new EmptyCommand(Scheduler::getInstance().getDefaultHostQueue());
648653

@@ -653,7 +658,7 @@ EmptyCommand *Scheduler::GraphBuilder::addEmptyCmd(
653658
EmptyCmd->MEnqueueStatus = EnqueueResultT::SyclEnqueueBlocked;
654659
EmptyCmd->MBlockReason = Reason;
655660

656-
for (Requirement *Req : Reqs) {
661+
for (T *Req : Reqs) {
657662
MemObjRecord *Record = getOrInsertMemObjRecord(Queue, Req);
658663
AllocaCommandBase *AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue);
659664
EmptyCmd->addRequirement(Cmd, AllocaCmd, Req);
@@ -941,23 +946,19 @@ void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
941946
EmptyCommand *EmptyCmd = nullptr;
942947

943948
if (Dep.MDepRequirement) {
944-
Requirement *Req = const_cast<Requirement *>(Dep.MDepRequirement);
945-
946949
// make ConnectCmd depend on requirement
947-
{
948-
ConnectCmd->addDep(Dep);
949-
assert(reinterpret_cast<Command *>(DepEvent->getCommand()) ==
950-
Dep.MDepCommand);
951-
// add user to Dep.MDepCommand is already performed beyond this if branch
950+
ConnectCmd->addDep(Dep);
951+
assert(reinterpret_cast<Command *>(DepEvent->getCommand()) ==
952+
Dep.MDepCommand);
953+
// add user to Dep.MDepCommand is already performed beyond this if branch
952954

953-
MemObjRecord *Record = getMemObjRecord(Req->MSYCLMemObj);
955+
MemObjRecord *Record = getMemObjRecord(Dep.MDepRequirement->MSYCLMemObj);
954956

955-
updateLeaves({ Dep.MDepCommand }, Record, Req->MAccessMode);
956-
addNodeToLeaves(Record, ConnectCmd, Req->MAccessMode);
957-
}
957+
updateLeaves({ Dep.MDepCommand }, Record, Dep.MDepRequirement->MAccessMode);
958+
addNodeToLeaves(Record, ConnectCmd, Dep.MDepRequirement->MAccessMode);
958959

959-
const std::vector<Requirement *> Reqs(1, Req);
960-
EmptyCmd = addEmptyCmd(ConnectCmd, Reqs,
960+
const std::vector<const Requirement *> Reqs(1, Dep.MDepRequirement);
961+
EmptyCmd = addEmptyCmd<>(ConnectCmd, Reqs,
961962
Scheduler::getInstance().getDefaultHostQueue(),
962963
Command::BlockReason::HostTask);
963964
// Dependencies for EmptyCmd are set in addEmptyCmd for provided Reqs.
@@ -970,9 +971,9 @@ void Scheduler::GraphBuilder::connectDepEvent(Command *const Cmd,
970971
Cmd->addDep(CmdDep);
971972
}
972973
} else {
973-
EmptyCmd = addEmptyCmd(ConnectCmd, {},
974-
Scheduler::getInstance().getDefaultHostQueue(),
975-
Command::BlockReason::HostTask);
974+
EmptyCmd = addEmptyCmd<Requirement>(
975+
ConnectCmd, {}, Scheduler::getInstance().getDefaultHostQueue(),
976+
Command::BlockReason::HostTask);
976977

977978
// There is no requirement thus, empty command will only depend on
978979
// ConnectCmd via its event.

sycl/source/detail/scheduler/scheduler.hpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ class Scheduler {
493493
/// \return a pointer to MemObjRecord for pointer to memory object. If the
494494
/// record is not found, nullptr is returned.
495495
MemObjRecord *getOrInsertMemObjRecord(const QueueImplPtr &Queue,
496-
Requirement *Req);
496+
const Requirement *Req);
497497

498498
/// Decrements leaf counters for all leaves of the record.
499499
void decrementLeafCountersForRecord(MemObjRecord *Record);
@@ -546,20 +546,24 @@ class Scheduler {
546546
const QueueImplPtr &Queue);
547547

548548
/// Finds dependencies for the requirement.
549-
std::set<Command *> findDepsForReq(MemObjRecord *Record, Requirement *Req,
549+
std::set<Command *> findDepsForReq(MemObjRecord *Record,
550+
const Requirement *Req,
550551
const ContextImplPtr &Context);
551552

552-
EmptyCommand *addEmptyCmd(Command *Cmd,
553-
const std::vector<Requirement *> &Req,
554-
const QueueImplPtr &Queue,
555-
Command::BlockReason Reason);
553+
template<typename T>
554+
typename std::enable_if<std::is_same<typename std::remove_cv<T>::type,
555+
Requirement>::value,
556+
EmptyCommand *>::type
557+
addEmptyCmd(Command *Cmd, const std::vector<T *> &Req,
558+
const QueueImplPtr &Queue, Command::BlockReason Reason);
556559

557560
protected:
558561
/// Finds a command dependency corresponding to the record.
559562
DepDesc findDepForRecord(Command *Cmd, MemObjRecord *Record);
560563

561564
/// Searches for suitable alloca in memory record.
562-
AllocaCommandBase *findAllocaForReq(MemObjRecord *Record, Requirement *Req,
565+
AllocaCommandBase *findAllocaForReq(MemObjRecord *Record,
566+
const Requirement *Req,
563567
const ContextImplPtr &Context);
564568

565569
friend class Command;
@@ -569,7 +573,7 @@ class Scheduler {
569573
///
570574
/// If none found, creates new one.
571575
AllocaCommandBase *getOrCreateAllocaForReq(MemObjRecord *Record,
572-
Requirement *Req,
576+
const Requirement *Req,
573577
QueueImplPtr Queue);
574578

575579
void markModifiedIfWrite(MemObjRecord *Record, Requirement *Req);

0 commit comments

Comments
 (0)