Skip to content

Commit ea4d084

Browse files
[NFC][SYCL] Use plain context_impl * for func params in scheduler (#19156)
Continuation of the refactoring in #18795 #18877 #18966 #18979 #18980 #18981 #19007 #19030 #19123 #19126
1 parent 3a61a20 commit ea4d084

File tree

4 files changed

+49
-48
lines changed

4 files changed

+49
-48
lines changed

sycl/source/detail/queue_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,8 +646,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
646646
// for in order ones.
647647
void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);
648648

649-
static ContextImplPtr getContext(queue_impl *Queue) {
650-
return Queue ? Queue->getContextImplPtr() : nullptr;
649+
static context_impl *getContext(queue_impl *Queue) {
650+
return Queue ? &Queue->getContextImpl() : nullptr;
651651
}
652652

653653
// Must be called under MMutex protection

sycl/source/detail/scheduler/commands.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ class DispatchHostTask {
429429
"Host task submissions should have an associated queue");
430430
interop_handle IH{MReqToMem, HostTask.MQueue,
431431
HostTask.MQueue->getDeviceImpl().shared_from_this(),
432-
HostTask.MQueue->getContextImplPtr()};
432+
HostTask.MQueue->getContextImpl().shared_from_this()};
433433
// TODO: should all the backends that support this entry point use this
434434
// for host task?
435435
auto &Queue = HostTask.MQueue;
@@ -2676,7 +2676,7 @@ void enqueueImpKernel(
26762676
detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
26772677
bool KernelHasSpecialCaptures) {
26782678
// Run OpenCL kernel
2679-
auto &ContextImpl = Queue.getContextImplPtr();
2679+
context_impl &ContextImpl = Queue.getContextImpl();
26802680
device_impl &DeviceImpl = Queue.getDeviceImpl();
26812681
ur_kernel_handle_t Kernel = nullptr;
26822682
std::mutex *KernelMutex = nullptr;
@@ -2714,7 +2714,7 @@ void enqueueImpKernel(
27142714
KernelMutex = SyclKernelImpl->getCacheMutex();
27152715
} else {
27162716
KernelCacheVal = detail::ProgramManager::getInstance().getOrCreateKernel(
2717-
*ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, NDRDesc);
2717+
ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, NDRDesc);
27182718
Kernel = KernelCacheVal->MKernelHandle;
27192719
KernelMutex = KernelCacheVal->MMutex;
27202720
Program = KernelCacheVal->MProgramHandle;
@@ -2726,7 +2726,7 @@ void enqueueImpKernel(
27262726

27272727
// Initialize device globals associated with this.
27282728
std::vector<ur_event_handle_t> DeviceGlobalInitEvents =
2729-
ContextImpl->initializeDeviceGlobals(Program, Queue);
2729+
ContextImpl.initializeDeviceGlobals(Program, Queue);
27302730
if (!DeviceGlobalInitEvents.empty()) {
27312731
std::vector<ur_event_handle_t> EventsWithDeviceGlobalInits;
27322732
EventsWithDeviceGlobalInits.reserve(RawEvents.size() +
@@ -2783,9 +2783,9 @@ ur_result_t enqueueReadWriteHostPipe(queue_impl &Queue,
27832783

27842784
ur_program_handle_t Program = nullptr;
27852785
device Device = Queue.get_device();
2786-
ContextImplPtr ContextImpl = Queue.getContextImplPtr();
2786+
context_impl &ContextImpl = Queue.getContextImpl();
27872787
std::optional<ur_program_handle_t> CachedProgram =
2788-
ContextImpl->getProgramForHostPipe(Device, hostPipeEntry);
2788+
ContextImpl.getProgramForHostPipe(Device, hostPipeEntry);
27892789
if (CachedProgram)
27902790
Program = *CachedProgram;
27912791
else {
@@ -3003,7 +3003,7 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
30033003
// Queue is created by graph_impl before creating command to submit to
30043004
// scheduler.
30053005
const AdapterPtr &Adapter = MQueue->getAdapter();
3006-
auto ContextImpl = MQueue->getContextImplPtr();
3006+
context_impl &ContextImpl = MQueue->getContextImpl();
30073007
device_impl &DeviceImpl = MQueue->getDeviceImpl();
30083008

30093009
// The CUDA & HIP backends don't have the equivalent of barrier
@@ -3032,7 +3032,7 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
30323032
false /* profilable*/
30333033
};
30343034
Adapter->call<sycl::detail::UrApiKind::urCommandBufferCreateExp>(
3035-
ContextImpl->getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
3035+
ContextImpl.getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
30363036
&ChildCommandBuffer);
30373037
}
30383038

@@ -3042,12 +3042,12 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
30423042
// available if a user asks for them inside the interop task scope
30433043
std::vector<interop_handle::ReqToMem> ReqToMem;
30443044
const std::vector<Requirement *> &HandlerReq = HostTask->getRequirements();
3045-
auto ReqToMemConv = [&ReqToMem, ContextImpl](Requirement *Req) {
3045+
auto ReqToMemConv = [&ReqToMem, &ContextImpl](Requirement *Req) {
30463046
const std::vector<AllocaCommandBase *> &AllocaCmds =
30473047
Req->MSYCLMemObj->MRecord->MAllocaCommands;
30483048

30493049
for (AllocaCommandBase *AllocaCmd : AllocaCmds)
3050-
if (ContextImpl.get() == getContext(AllocaCmd->getQueue())) {
3050+
if (&ContextImpl == getContext(AllocaCmd->getQueue())) {
30513051
auto MemArg =
30523052
reinterpret_cast<ur_mem_handle_t>(AllocaCmd->getMemAllocation());
30533053
ReqToMem.emplace_back(std::make_pair(Req, MemArg));
@@ -3067,8 +3067,8 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
30673067
ur_exp_command_buffer_handle_t InteropCommandBuffer =
30683068
ChildCommandBuffer ? ChildCommandBuffer : MCommandBuffer;
30693069
interop_handle IH{std::move(ReqToMem), MQueue,
3070-
DeviceImpl.shared_from_this(), ContextImpl,
3071-
InteropCommandBuffer};
3070+
DeviceImpl.shared_from_this(),
3071+
ContextImpl.shared_from_this(), InteropCommandBuffer};
30723072
CommandBufferNativeCommandData CustomOpData{
30733073
std::move(IH), HostTask->MHostTask->MInteropTask};
30743074

@@ -3470,7 +3470,7 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
34703470
EnqueueNativeCommandData CustomOpData{
34713471
interop_handle{std::move(ReqToMem), HostTask->MQueue,
34723472
HostTask->MQueue->getDeviceImpl().shared_from_this(),
3473-
HostTask->MQueue->getContextImplPtr()},
3473+
HostTask->MQueue->getContextImpl().shared_from_this()},
34743474
HostTask->MHostTask->MInteropTask};
34753475

34763476
ur_bool_t NativeCommandSupport = false;

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static bool IsSuitableSubReq(const Requirement *Req) {
5252
return Req->MIsSubBuffer;
5353
}
5454

55-
static bool isOnSameContext(const ContextImplPtr Context, queue_impl *Queue) {
55+
static bool isOnSameContext(context_impl *Context, queue_impl *Queue) {
5656
// Covers case for host usage (nullptr == nullptr) and existing device
5757
// contexts comparison.
5858
return Context == queue_impl::getContext(Queue);
@@ -233,8 +233,8 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
233233
"shouldn't lead to any enqueuing (no linked "
234234
"alloca or exceeding the leaf limit).");
235235
} else
236-
MemObject->MRecord.reset(new MemObjRecord{
237-
queue_impl::getContext(Queue).get(), LeafLimit, AllocateDependency});
236+
MemObject->MRecord.reset(new MemObjRecord{queue_impl::getContext(Queue),
237+
LeafLimit, AllocateDependency});
238238

239239
MMemObjs.push_back(MemObject);
240240
return MemObject->MRecord.get();
@@ -346,15 +346,16 @@ Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
346346
}
347347

348348
AllocaCommandBase *AllocaCmdSrc =
349-
findAllocaForReq(Record, Req, Record->MCurContext);
349+
findAllocaForReq(Record, Req, Record->getCurContext());
350350
if (!AllocaCmdSrc && IsSuitableSubReq(Req)) {
351351
// Since no alloca command for the sub buffer requirement was found in the
352352
// current context, need to find a parent alloca command for it (it must be
353353
// there)
354354
auto IsSuitableAlloca = [Record](AllocaCommandBase *AllocaCmd) {
355-
bool Res = isOnSameContext(Record->MCurContext, AllocaCmd->getQueue()) &&
356-
// Looking for a parent buffer alloca command
357-
AllocaCmd->getType() == Command::CommandType::ALLOCA;
355+
bool Res =
356+
isOnSameContext(Record->getCurContext(), AllocaCmd->getQueue()) &&
357+
// Looking for a parent buffer alloca command
358+
AllocaCmd->getType() == Command::CommandType::ALLOCA;
358359
return Res;
359360
};
360361
const auto It =
@@ -384,10 +385,9 @@ Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
384385
NewCmd = insertMapUnmapForLinkedCmds(AllocaCmdSrc, AllocaCmdDst, MapMode);
385386
Record->MHostAccess = MapMode;
386387
} else {
387-
388388
if ((Req->MAccessMode == access::mode::discard_write) ||
389389
(Req->MAccessMode == access::mode::discard_read_write)) {
390-
Record->MCurContext = Context;
390+
Record->setCurContext(Context);
391391
return nullptr;
392392
} else {
393393
// Full copy of buffer is needed to avoid loss of data that may be caused
@@ -409,7 +409,7 @@ Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
409409
addNodeToLeaves(Record, NewCmd, access::mode::read_write, ToEnqueue);
410410
for (Command *Cmd : ToCleanUp)
411411
cleanupCommand(Cmd);
412-
Record->MCurContext = Context;
412+
Record->setCurContext(Context);
413413
return NewCmd;
414414
}
415415

@@ -422,7 +422,8 @@ Command *Scheduler::GraphBuilder::remapMemoryObject(
422422
AllocaCommandBase *LinkedAllocaCmd = HostAllocaCmd->MLinkedAllocaCmd;
423423
assert(LinkedAllocaCmd && "Linked alloca command expected");
424424

425-
std::set<Command *> Deps = findDepsForReq(Record, Req, Record->MCurContext);
425+
std::set<Command *> Deps =
426+
findDepsForReq(Record, Req, Record->getCurContext());
426427

427428
UnMapMemObject *UnMapCmd = new UnMapMemObject(
428429
LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
@@ -473,7 +474,7 @@ Scheduler::GraphBuilder::addCopyBack(Requirement *Req,
473474

474475
std::set<Command *> Deps = findDepsForReq(Record, Req, nullptr);
475476
AllocaCommandBase *SrcAllocaCmd =
476-
findAllocaForReq(Record, Req, Record->MCurContext);
477+
findAllocaForReq(Record, Req, Record->getCurContext());
477478

478479
auto MemCpyCmdUniquePtr = std::make_unique<MemCpyCommandHost>(
479480
*SrcAllocaCmd->getRequirement(), SrcAllocaCmd, *Req, &Req->MData,
@@ -525,7 +526,7 @@ Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
525526
AllocaCommandBase *HostAllocaCmd =
526527
getOrCreateAllocaForReq(Record, Req, nullptr, ToEnqueue);
527528

528-
if (isOnSameContext(Record->MCurContext, HostAllocaCmd->getQueue())) {
529+
if (isOnSameContext(Record->getCurContext(), HostAllocaCmd->getQueue())) {
529530
if (!isAccessModeAllowed(Req->MAccessMode, Record->MHostAccess)) {
530531
remapMemoryObject(Record, Req,
531532
Req->MIsSubBuffer ? (static_cast<AllocaSubBufCommand *>(
@@ -571,10 +572,8 @@ Command *Scheduler::GraphBuilder::addCGUpdateHost(
571572
/// 1. New and examined commands only read -> can bypass
572573
/// 2. New and examined commands has non-overlapping requirements -> can bypass
573574
/// 3. New and examined commands have different contexts -> cannot bypass
574-
std::set<Command *>
575-
Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
576-
const Requirement *Req,
577-
const ContextImplPtr &Context) {
575+
std::set<Command *> Scheduler::GraphBuilder::findDepsForReq(
576+
MemObjRecord *Record, const Requirement *Req, context_impl *Context) {
578577
std::set<Command *> RetDeps;
579578
std::vector<Command *> Visited;
580579
const bool ReadOnlyReq = Req->MAccessMode == access::mode::read;
@@ -644,7 +643,7 @@ DepDesc Scheduler::GraphBuilder::findDepForRecord(Command *Cmd,
644643
// The function searches for the alloca command matching context and
645644
// requirement.
646645
AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
647-
MemObjRecord *Record, const Requirement *Req, const ContextImplPtr &Context,
646+
MemObjRecord *Record, const Requirement *Req, context_impl *Context,
648647
bool AllowConst) {
649648
auto IsSuitableAlloca = [&Context, Req,
650649
AllowConst](AllocaCommandBase *AllocaCmd) {
@@ -663,7 +662,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
663662
return (Record->MAllocaCommands.end() != It) ? *It : nullptr;
664663
}
665664

666-
static bool checkHostUnifiedMemory(const ContextImplPtr &Ctx) {
665+
static bool checkHostUnifiedMemory(context_impl *Ctx) {
667666
if (const char *HUMConfig = SYCLConfig<SYCL_HOST_UNIFIED_MEMORY>::get()) {
668667
if (std::strcmp(HUMConfig, "0") == 0)
669668
return Ctx == nullptr;
@@ -744,7 +743,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
744743
Record->MAllocaCommands.push_back(HostAllocaCmd);
745744
Record->MWriteLeaves.push_back(HostAllocaCmd, ToEnqueue);
746745
++(HostAllocaCmd->MLeafCounter);
747-
Record->MCurContext = nullptr;
746+
Record->setCurContext(nullptr);
748747
}
749748
}
750749
} else {
@@ -768,11 +767,12 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
768767
bool PinnedHostMemory = MemObj->usesPinnedHostMemory();
769768

770769
bool HostUnifiedMemoryOnNonHostDevice =
771-
Queue == nullptr ? checkHostUnifiedMemory(Record->MCurContext)
772-
: HostUnifiedMemory;
770+
Queue == nullptr
771+
? checkHostUnifiedMemory(Record->getCurContext())
772+
: HostUnifiedMemory;
773773
if (PinnedHostMemory || HostUnifiedMemoryOnNonHostDevice) {
774774
AllocaCommandBase *LinkedAllocaCmdCand = findAllocaForReq(
775-
Record, Req, Record->MCurContext, /*AllowConst=*/false);
775+
Record, Req, Record->getCurContext(), /*AllowConst=*/false);
776776

777777
// Cannot setup link if candidate is linked already
778778
if (LinkedAllocaCmdCand &&
@@ -812,7 +812,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
812812
AllocaCmd->MIsActive = false;
813813
} else {
814814
LinkedAllocaCmd->MIsActive = false;
815-
Record->MCurContext = Context;
815+
Record->setCurContext(Context);
816816

817817
std::set<Command *> Deps = findDepsForReq(Record, Req, Context);
818818
for (Command *Dep : Deps) {
@@ -965,7 +965,7 @@ Command *Scheduler::GraphBuilder::addCG(
965965
AllocaCmd =
966966
getOrCreateAllocaForReq(Record, Req, QueueForAlloca, ToEnqueue);
967967

968-
isSameCtx = isOnSameContext(Record->MCurContext, QueueForAlloca);
968+
isSameCtx = isOnSameContext(Record->getCurContext(), QueueForAlloca);
969969
}
970970

971971
// If there is alloca command we need to check if the latest memory is in
@@ -992,7 +992,7 @@ Command *Scheduler::GraphBuilder::addCG(
992992
const detail::CGHostTask &HT =
993993
static_cast<detail::CGHostTask &>(NewCmd->getCG());
994994

995-
if (!isOnSameContext(Record->MCurContext, HT.MQueue.get())) {
995+
if (!isOnSameContext(Record->getCurContext(), HT.MQueue.get())) {
996996
NeedMemMoveToHost = true;
997997
MemMoveTargetQueue = HT.MQueue.get();
998998
}
@@ -1226,9 +1226,7 @@ Command *Scheduler::GraphBuilder::connectDepEvent(
12261226
try {
12271227
std::shared_ptr<detail::HostTask> HT(new detail::HostTask);
12281228
std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
1229-
std::move(HT),
1230-
/* Queue = */ Cmd->getQueue(),
1231-
/* Context = */ {},
1229+
std::move(HT), /* Queue = */ Cmd->getQueue(), /* Context = */ nullptr,
12321230
/* Args = */ {},
12331231
detail::CG::StorageInitHelper(
12341232
/* ArgsStorage = */ {}, /* AccStorage = */ {},
@@ -1302,7 +1300,7 @@ Command *Scheduler::GraphBuilder::addCommandGraphUpdate(
13021300

13031301
AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue, ToEnqueue);
13041302

1305-
isSameCtx = isOnSameContext(Record->MCurContext, Queue);
1303+
isSameCtx = isOnSameContext(Record->getCurContext(), Queue);
13061304
}
13071305

13081306
if (!isSameCtx) {

sycl/source/detail/scheduler/scheduler.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ class event_impl;
185185
class context_impl;
186186
class DispatchHostTask;
187187

188-
using ContextImplPtr = std::shared_ptr<detail::context_impl>;
189188
using EventImplPtr = std::shared_ptr<detail::event_impl>;
190189
using StreamImplPtr = std::shared_ptr<detail::stream_impl>;
191190

@@ -214,6 +213,10 @@ struct MemObjRecord {
214213

215214
// The context which has the latest state of the memory object.
216215
std::shared_ptr<context_impl> MCurContext;
216+
context_impl *getCurContext() { return MCurContext.get(); }
217+
void setCurContext(context_impl *Ctx) {
218+
MCurContext = Ctx ? Ctx->shared_from_this() : nullptr;
219+
}
217220

218221
// The mode this object can be accessed from the host (host_accessor).
219222
// Valid only if the current usage is on host.
@@ -688,7 +691,7 @@ class Scheduler {
688691
/// Finds dependencies for the requirement.
689692
std::set<Command *> findDepsForReq(MemObjRecord *Record,
690693
const Requirement *Req,
691-
const ContextImplPtr &Context);
694+
context_impl *Context);
692695

693696
EmptyCommand *addEmptyCmd(Command *Cmd,
694697
const std::vector<Requirement *> &Req,
@@ -702,7 +705,7 @@ class Scheduler {
702705
/// Searches for suitable alloca in memory record.
703706
AllocaCommandBase *findAllocaForReq(MemObjRecord *Record,
704707
const Requirement *Req,
705-
const ContextImplPtr &Context,
708+
context_impl *Context,
706709
bool AllowConst = true);
707710

708711
friend class Command;

0 commit comments

Comments (0)