[NFC][SYCL] Use plain context_impl * for func params in scheduler #19156

Merged: 1 commit, Jun 26, 2025
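In short: scheduler functions that only observe a context now take a plain context_impl * instead of a ContextImplPtr (an alias for std::shared_ptr<detail::context_impl>), so reference counts are bumped only where ownership is actually shared. Below is a minimal, self-contained sketch of the before/after convention; the types are stand-ins, not the real SYCL classes.

// Stand-in types only; the real queue_impl/context_impl live in sycl/source/detail.
#include <memory>

struct context_impl {};

struct queue_impl {
  std::shared_ptr<context_impl> MContext = std::make_shared<context_impl>();
  context_impl &getContextImpl() { return *MContext; }
};

// Before: a static getContext() returned a ContextImplPtr copy, bumping the
// refcount although no lifetime is extended.
// After: a raw pointer states "non-owning, may be null" directly in the type.
static context_impl *getContext(queue_impl *Queue) {
  return Queue ? &Queue->getContextImpl() : nullptr;
}

static bool isOnSameContext(context_impl *Context, queue_impl *Queue) {
  // nullptr == nullptr also covers host usage, as in the patch.
  return Context == getContext(Queue);
}

int main() {
  queue_impl Q;
  return isOnSameContext(getContext(&Q), &Q) ? 0 : 1;
}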
4 changes: 2 additions & 2 deletions sycl/source/detail/queue_impl.hpp
@@ -646,8 +646,8 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
// for in order ones.
void revisitUnenqueuedCommandsState(const EventImplPtr &CompletedHostTask);

- static ContextImplPtr getContext(queue_impl *Queue) {
-   return Queue ? Queue->getContextImplPtr() : nullptr;
+ static context_impl *getContext(queue_impl *Queue) {
+   return Queue ? &Queue->getContextImpl() : nullptr;
}

// Must be called under MMutex protection
26 changes: 13 additions & 13 deletions sycl/source/detail/scheduler/commands.cpp
@@ -429,7 +429,7 @@ class DispatchHostTask {
"Host task submissions should have an associated queue");
interop_handle IH{MReqToMem, HostTask.MQueue,
HostTask.MQueue->getDeviceImpl().shared_from_this(),
- HostTask.MQueue->getContextImplPtr()};
+ HostTask.MQueue->getContextImpl().shared_from_this()};
// TODO: should all the backends that support this entry point use this
// for host task?
auto &Queue = HostTask.MQueue;
@@ -2677,7 +2677,7 @@ void enqueueImpKernel(
detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
bool KernelHasSpecialCaptures) {
// Run OpenCL kernel
- auto &ContextImpl = Queue.getContextImplPtr();
+ context_impl &ContextImpl = Queue.getContextImpl();
device_impl &DeviceImpl = Queue.getDeviceImpl();
ur_kernel_handle_t Kernel = nullptr;
std::mutex *KernelMutex = nullptr;
@@ -2715,7 +2715,7 @@
KernelMutex = SyclKernelImpl->getCacheMutex();
} else {
KernelCacheVal = detail::ProgramManager::getInstance().getOrCreateKernel(
- *ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, NDRDesc);
+ ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, NDRDesc);
Kernel = KernelCacheVal->MKernelHandle;
KernelMutex = KernelCacheVal->MMutex;
Program = KernelCacheVal->MProgramHandle;
@@ -2727,7 +2727,7 @@

// Initialize device globals associated with this.
std::vector<ur_event_handle_t> DeviceGlobalInitEvents =
- ContextImpl->initializeDeviceGlobals(Program, Queue);
+ ContextImpl.initializeDeviceGlobals(Program, Queue);
if (!DeviceGlobalInitEvents.empty()) {
std::vector<ur_event_handle_t> EventsWithDeviceGlobalInits;
EventsWithDeviceGlobalInits.reserve(RawEvents.size() +
@@ -2784,9 +2784,9 @@ ur_result_t enqueueReadWriteHostPipe(queue_impl &Queue,

ur_program_handle_t Program = nullptr;
device Device = Queue.get_device();
- ContextImplPtr ContextImpl = Queue.getContextImplPtr();
+ context_impl &ContextImpl = Queue.getContextImpl();
std::optional<ur_program_handle_t> CachedProgram =
- ContextImpl->getProgramForHostPipe(Device, hostPipeEntry);
+ ContextImpl.getProgramForHostPipe(Device, hostPipeEntry);
if (CachedProgram)
Program = *CachedProgram;
else {
@@ -3004,7 +3004,7 @@ ur_result_t ExecCGCommand::enqueueImpCommandBuffer() {
// Queue is created by graph_impl before creating command to submit to
// scheduler.
const AdapterPtr &Adapter = MQueue->getAdapter();
- auto ContextImpl = MQueue->getContextImplPtr();
+ context_impl &ContextImpl = MQueue->getContextImpl();
device_impl &DeviceImpl = MQueue->getDeviceImpl();

// The CUDA & HIP backends don't have the equivalent of barrier
@@ -3033,7 +3033,7 @@
false /* profilable*/
};
Adapter->call<sycl::detail::UrApiKind::urCommandBufferCreateExp>(
- ContextImpl->getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
+ ContextImpl.getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
&ChildCommandBuffer);
}

@@ -3043,12 +3043,12 @@
// available if a user asks for them inside the interop task scope
std::vector<interop_handle::ReqToMem> ReqToMem;
const std::vector<Requirement *> &HandlerReq = HostTask->getRequirements();
- auto ReqToMemConv = [&ReqToMem, ContextImpl](Requirement *Req) {
+ auto ReqToMemConv = [&ReqToMem, &ContextImpl](Requirement *Req) {
const std::vector<AllocaCommandBase *> &AllocaCmds =
Req->MSYCLMemObj->MRecord->MAllocaCommands;

for (AllocaCommandBase *AllocaCmd : AllocaCmds)
- if (ContextImpl.get() == getContext(AllocaCmd->getQueue())) {
+ if (&ContextImpl == getContext(AllocaCmd->getQueue())) {
auto MemArg =
reinterpret_cast<ur_mem_handle_t>(AllocaCmd->getMemAllocation());
ReqToMem.emplace_back(std::make_pair(Req, MemArg));
@@ -3068,8 +3068,8 @@
ur_exp_command_buffer_handle_t InteropCommandBuffer =
ChildCommandBuffer ? ChildCommandBuffer : MCommandBuffer;
interop_handle IH{std::move(ReqToMem), MQueue,
- DeviceImpl.shared_from_this(), ContextImpl,
- InteropCommandBuffer};
+ DeviceImpl.shared_from_this(),
+ ContextImpl.shared_from_this(), InteropCommandBuffer};
CommandBufferNativeCommandData CustomOpData{
std::move(IH), HostTask->MHostTask->MInteropTask};

@@ -3471,7 +3471,7 @@ ur_result_t ExecCGCommand::enqueueImpQueue() {
EnqueueNativeCommandData CustomOpData{
interop_handle{std::move(ReqToMem), HostTask->MQueue,
HostTask->MQueue->getDeviceImpl().shared_from_this(),
- HostTask->MQueue->getContextImplPtr()},
+ HostTask->MQueue->getContextImpl().shared_from_this()},
HostTask->MHostTask->MInteropTask};

ur_bool_t NativeCommandSupport = false;
58 changes: 28 additions & 30 deletions sycl/source/detail/scheduler/graph_builder.cpp
@@ -52,7 +52,7 @@ static bool IsSuitableSubReq(const Requirement *Req) {
return Req->MIsSubBuffer;
}

- static bool isOnSameContext(const ContextImplPtr Context, queue_impl *Queue) {
+ static bool isOnSameContext(context_impl *Context, queue_impl *Queue) {
// Covers case for host usage (nullptr == nullptr) and existing device
// contexts comparison.
return Context == queue_impl::getContext(Queue);
@@ -233,8 +233,8 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
"shouldn't lead to any enqueuing (no linked "
"alloca or exceeding the leaf limit).");
} else
- MemObject->MRecord.reset(new MemObjRecord{
- queue_impl::getContext(Queue).get(), LeafLimit, AllocateDependency});
+ MemObject->MRecord.reset(new MemObjRecord{queue_impl::getContext(Queue),
+ LeafLimit, AllocateDependency});

MMemObjs.push_back(MemObject);
return MemObject->MRecord.get();
@@ -346,15 +346,16 @@ Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
}

AllocaCommandBase *AllocaCmdSrc =
- findAllocaForReq(Record, Req, Record->MCurContext);
+ findAllocaForReq(Record, Req, Record->getCurContext());
if (!AllocaCmdSrc && IsSuitableSubReq(Req)) {
// Since no alloca command for the sub buffer requirement was found in the
// current context, need to find a parent alloca command for it (it must be
// there)
auto IsSuitableAlloca = [Record](AllocaCommandBase *AllocaCmd) {
- bool Res = isOnSameContext(Record->MCurContext, AllocaCmd->getQueue()) &&
- // Looking for a parent buffer alloca command
- AllocaCmd->getType() == Command::CommandType::ALLOCA;
+ bool Res =
+ isOnSameContext(Record->getCurContext(), AllocaCmd->getQueue()) &&
+ // Looking for a parent buffer alloca command
+ AllocaCmd->getType() == Command::CommandType::ALLOCA;
return Res;
};
const auto It =
@@ -384,10 +385,9 @@
NewCmd = insertMapUnmapForLinkedCmds(AllocaCmdSrc, AllocaCmdDst, MapMode);
Record->MHostAccess = MapMode;
} else {

if ((Req->MAccessMode == access::mode::discard_write) ||
(Req->MAccessMode == access::mode::discard_read_write)) {
- Record->MCurContext = Context;
+ Record->setCurContext(Context);
return nullptr;
} else {
// Full copy of buffer is needed to avoid loss of data that may be caused
@@ -409,7 +409,7 @@ Scheduler::GraphBuilder::insertMemoryMove(
addNodeToLeaves(Record, NewCmd, access::mode::read_write, ToEnqueue);
for (Command *Cmd : ToCleanUp)
cleanupCommand(Cmd);
- Record->MCurContext = Context;
+ Record->setCurContext(Context);
return NewCmd;
}

@@ -422,7 +422,8 @@ Command *Scheduler::GraphBuilder::remapMemoryObject(
AllocaCommandBase *LinkedAllocaCmd = HostAllocaCmd->MLinkedAllocaCmd;
assert(LinkedAllocaCmd && "Linked alloca command expected");

- std::set<Command *> Deps = findDepsForReq(Record, Req, Record->MCurContext);
+ std::set<Command *> Deps =
+ findDepsForReq(Record, Req, Record->getCurContext());

UnMapMemObject *UnMapCmd = new UnMapMemObject(
LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement(),
@@ -473,7 +474,7 @@ Scheduler::GraphBuilder::addCopyBack(Requirement *Req,

std::set<Command *> Deps = findDepsForReq(Record, Req, nullptr);
AllocaCommandBase *SrcAllocaCmd =
- findAllocaForReq(Record, Req, Record->MCurContext);
+ findAllocaForReq(Record, Req, Record->getCurContext());

auto MemCpyCmdUniquePtr = std::make_unique<MemCpyCommandHost>(
*SrcAllocaCmd->getRequirement(), SrcAllocaCmd, *Req, &Req->MData,
@@ -525,7 +526,7 @@ Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
AllocaCommandBase *HostAllocaCmd =
getOrCreateAllocaForReq(Record, Req, nullptr, ToEnqueue);

- if (isOnSameContext(Record->MCurContext, HostAllocaCmd->getQueue())) {
+ if (isOnSameContext(Record->getCurContext(), HostAllocaCmd->getQueue())) {
if (!isAccessModeAllowed(Req->MAccessMode, Record->MHostAccess)) {
remapMemoryObject(Record, Req,
Req->MIsSubBuffer ? (static_cast<AllocaSubBufCommand *>(
@@ -571,10 +572,8 @@ Command *Scheduler::GraphBuilder::addCGUpdateHost(
/// 1. New and examined commands only read -> can bypass
/// 2. New and examined commands has non-overlapping requirements -> can bypass
/// 3. New and examined commands have different contexts -> cannot bypass
- std::set<Command *>
- Scheduler::GraphBuilder::findDepsForReq(MemObjRecord *Record,
- const Requirement *Req,
- const ContextImplPtr &Context) {
+ std::set<Command *> Scheduler::GraphBuilder::findDepsForReq(
+ MemObjRecord *Record, const Requirement *Req, context_impl *Context) {
std::set<Command *> RetDeps;
std::vector<Command *> Visited;
const bool ReadOnlyReq = Req->MAccessMode == access::mode::read;
@@ -644,7 +643,7 @@ DepDesc Scheduler::GraphBuilder::findDepForRecord(Command *Cmd,
// The function searches for the alloca command matching context and
// requirement.
AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
- MemObjRecord *Record, const Requirement *Req, const ContextImplPtr &Context,
+ MemObjRecord *Record, const Requirement *Req, context_impl *Context,
bool AllowConst) {
auto IsSuitableAlloca = [&Context, Req,
AllowConst](AllocaCommandBase *AllocaCmd) {
Expand All @@ -663,7 +662,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::findAllocaForReq(
return (Record->MAllocaCommands.end() != It) ? *It : nullptr;
}

- static bool checkHostUnifiedMemory(const ContextImplPtr &Ctx) {
+ static bool checkHostUnifiedMemory(context_impl *Ctx) {
if (const char *HUMConfig = SYCLConfig<SYCL_HOST_UNIFIED_MEMORY>::get()) {
if (std::strcmp(HUMConfig, "0") == 0)
return Ctx == nullptr;
@@ -744,7 +743,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
Record->MAllocaCommands.push_back(HostAllocaCmd);
Record->MWriteLeaves.push_back(HostAllocaCmd, ToEnqueue);
++(HostAllocaCmd->MLeafCounter);
- Record->MCurContext = nullptr;
+ Record->setCurContext(nullptr);
}
}
} else {
@@ -768,11 +767,12 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
bool PinnedHostMemory = MemObj->usesPinnedHostMemory();

bool HostUnifiedMemoryOnNonHostDevice =
- Queue == nullptr ? checkHostUnifiedMemory(Record->MCurContext)
- : HostUnifiedMemory;
+ Queue == nullptr
+ ? checkHostUnifiedMemory(Record->getCurContext())
+ : HostUnifiedMemory;
if (PinnedHostMemory || HostUnifiedMemoryOnNonHostDevice) {
AllocaCommandBase *LinkedAllocaCmdCand = findAllocaForReq(
- Record, Req, Record->MCurContext, /*AllowConst=*/false);
+ Record, Req, Record->getCurContext(), /*AllowConst=*/false);

// Cannot setup link if candidate is linked already
if (LinkedAllocaCmdCand &&
@@ -812,7 +812,7 @@ AllocaCommandBase *Scheduler::GraphBuilder::getOrCreateAllocaForReq(
AllocaCmd->MIsActive = false;
} else {
LinkedAllocaCmd->MIsActive = false;
- Record->MCurContext = Context;
+ Record->setCurContext(Context);

std::set<Command *> Deps = findDepsForReq(Record, Req, Context);
for (Command *Dep : Deps) {
@@ -965,7 +965,7 @@ Command *Scheduler::GraphBuilder::addCG(
AllocaCmd =
getOrCreateAllocaForReq(Record, Req, QueueForAlloca, ToEnqueue);

- isSameCtx = isOnSameContext(Record->MCurContext, QueueForAlloca);
+ isSameCtx = isOnSameContext(Record->getCurContext(), QueueForAlloca);
}

// If there is alloca command we need to check if the latest memory is in
@@ -992,7 +992,7 @@ Command *Scheduler::GraphBuilder::addCG(
const detail::CGHostTask &HT =
static_cast<detail::CGHostTask &>(NewCmd->getCG());

- if (!isOnSameContext(Record->MCurContext, HT.MQueue.get())) {
+ if (!isOnSameContext(Record->getCurContext(), HT.MQueue.get())) {
NeedMemMoveToHost = true;
MemMoveTargetQueue = HT.MQueue.get();
}
@@ -1226,9 +1226,7 @@ Command *Scheduler::GraphBuilder::connectDepEvent(
try {
std::shared_ptr<detail::HostTask> HT(new detail::HostTask);
std::unique_ptr<detail::CG> ConnectCG(new detail::CGHostTask(
- std::move(HT),
- /* Queue = */ Cmd->getQueue(),
- /* Context = */ {},
+ std::move(HT), /* Queue = */ Cmd->getQueue(), /* Context = */ nullptr,
/* Args = */ {},
detail::CG::StorageInitHelper(
/* ArgsStorage = */ {}, /* AccStorage = */ {},
Expand Down Expand Up @@ -1302,7 +1300,7 @@ Command *Scheduler::GraphBuilder::addCommandGraphUpdate(

AllocaCmd = getOrCreateAllocaForReq(Record, Req, Queue, ToEnqueue);

- isSameCtx = isOnSameContext(Record->MCurContext, Queue);
+ isSameCtx = isOnSameContext(Record->getCurContext(), Queue);
}

if (!isSameCtx) {
9 changes: 6 additions & 3 deletions sycl/source/detail/scheduler/scheduler.hpp
@@ -185,7 +185,6 @@ class event_impl;
class context_impl;
class DispatchHostTask;

- using ContextImplPtr = std::shared_ptr<detail::context_impl>;
using EventImplPtr = std::shared_ptr<detail::event_impl>;
using StreamImplPtr = std::shared_ptr<detail::stream_impl>;

@@ -214,6 +213,10 @@ struct MemObjRecord {

// The context which has the latest state of the memory object.
std::shared_ptr<context_impl> MCurContext;
+ context_impl *getCurContext() { return MCurContext.get(); }
Review comment (Contributor):
Could there be a benefit to having a const overload?

Suggested change:
- context_impl *getCurContext() { return MCurContext.get(); }
+ context_impl *getCurContext() { return MCurContext.get(); }
+ const context_impl *getCurContext() const { return MCurContext.get(); }

Reply (Contributor Author):
Const-correctness, and the choices we would have to make there (technically, C++ allows returning a non-const pointer from a const method), is intentionally left for a future refactoring separate from this activity.
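To illustrate the language point referenced above, here is a small sketch using stand-in types (not the real scheduler classes): a const member function only makes the shared_ptr member itself const, so it can still legally hand out a non-const context_impl *.

#include <memory>

struct context_impl {};

struct MemObjRecord {
  std::shared_ptr<context_impl> MCurContext;

  // Compiles fine: 'const' applies to the shared_ptr member, not to the
  // pointed-to context_impl, so a mutable pointer escapes a const method.
  context_impl *getCurContext() const { return MCurContext.get(); }
};

int main() {
  const MemObjRecord Rec{std::make_shared<context_impl>()};
  context_impl *Ctx = Rec.getCurContext(); // non-const access via const object
  return Ctx != nullptr ? 0 : 1;
}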

+ void setCurContext(context_impl *Ctx) {
+   MCurContext = Ctx ? Ctx->shared_from_this() : nullptr;
Review comment (Contributor):
My understanding was that the shared_from_this functionality can be used by the shared_ptr constructor, so:

Suggested change:
- MCurContext = Ctx ? Ctx->shared_from_this() : nullptr;
+ MCurContext = std::shared_ptr<context_impl>{Ctx};

Reply (Contributor Author):
That only works during creation of the very first shared_ptr, which has already been done in context_impl::create / std::make_shared, AFAIK. https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2016/p0033r1.html might be helpful reading on that subject as well.
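A compilable sketch of the distinction, with a stand-in context_impl (an assumption, not the real class): wrapping an already-owned raw pointer in a fresh shared_ptr creates a second control block and a double delete, whereas shared_from_this() joins the control block created by the original make_shared.

#include <memory>

struct context_impl : std::enable_shared_from_this<context_impl> {};

int main() {
  auto Owner = std::make_shared<context_impl>(); // first (and only) control block

  // Correct: shares ownership with Owner; use_count becomes 2.
  std::shared_ptr<context_impl> Ok = Owner->shared_from_this();

  // Would be a bug: std::shared_ptr<context_impl> Bad{Owner.get()};
  // creates an independent control block -> double delete at scope exit.

  return Ok.use_count() == 2 ? 0 : 1;
}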

+ }

// The mode this object can be accessed from the host (host_accessor).
// Valid only if the current usage is on host.
@@ -688,7 +691,7 @@ class Scheduler {
/// Finds dependencies for the requirement.
std::set<Command *> findDepsForReq(MemObjRecord *Record,
const Requirement *Req,
- const ContextImplPtr &Context);
+ context_impl *Context);

EmptyCommand *addEmptyCmd(Command *Cmd,
const std::vector<Requirement *> &Req,
Expand All @@ -702,7 +705,7 @@ class Scheduler {
/// Searches for suitable alloca in memory record.
AllocaCommandBase *findAllocaForReq(MemObjRecord *Record,
const Requirement *Req,
- const ContextImplPtr &Context,
+ context_impl *Context,
bool AllowConst = true);

friend class Command;