Skip to content

[NFC][SYCL] Update memory_manager to pass context_impl by raw pointers #18966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions sycl/source/detail/buffer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace detail {
#ifdef XPTI_ENABLE_INSTRUMENTATION
uint8_t GBufferStreamID;
#endif
void *buffer_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
void *buffer_impl::allocateMem(context_impl *Context, bool InitFromUserData,
void *HostPtr,
ur_event_handle_t &OutEventToWait) {
bool HostPtrReadOnly = false;
Expand All @@ -30,9 +30,9 @@ void *buffer_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
"Internal error. Allocating memory on the host "
"while having use_host_ptr property");
return MemoryManager::allocateMemBuffer(
std::move(Context), this, HostPtr, HostPtrReadOnly,
BaseT::getSizeInBytes(), BaseT::MInteropEvent, BaseT::MInteropContext,
MProps, OutEventToWait);
Context, this, HostPtr, HostPtrReadOnly, BaseT::getSizeInBytes(),
BaseT::MInteropEvent, BaseT::MInteropContext.get(), MProps,
OutEventToWait);
}
void buffer_impl::constructorNotification(const detail::code_location &CodeLoc,
void *UserObj, const void *HostObj,
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/buffer_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ class buffer_impl final : public SYCLMemObjT {
: BaseT(MemObject, SyclContext, OwnNativeHandle,
std::move(AvailableEvent), std::move(Allocator)) {}

void *allocateMem(ContextImplPtr Context, bool InitFromUserData,
void *HostPtr, ur_event_handle_t &OutEventToWait) override;
void *allocateMem(context_impl *Context, bool InitFromUserData, void *HostPtr,
ur_event_handle_t &OutEventToWait) override;
void constructorNotification(const detail::code_location &CodeLoc,
void *UserObj, const void *HostObj,
const void *Type, uint32_t Dim,
Expand Down
26 changes: 12 additions & 14 deletions sycl/source/detail/device_global_map_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(queue_impl &QueueImpl) {
assert(!MIsDeviceImageScopeDecorated &&
"USM allocations should not be acquired for device_global with "
"device_image_scope property.");
const std::shared_ptr<context_impl> &CtxImpl = QueueImpl.getContextImplPtr();
context_impl &CtxImpl = QueueImpl.getContextImpl();
const device_impl &DevImpl = QueueImpl.getDeviceImpl();
std::lock_guard<std::mutex> Lock(MDeviceToUSMPtrMapMutex);

auto DGUSMPtr = MDeviceToUSMPtrMap.find({&DevImpl, CtxImpl.get()});
auto DGUSMPtr = MDeviceToUSMPtrMap.find({&DevImpl, &CtxImpl});
if (DGUSMPtr != MDeviceToUSMPtrMap.end())
return DGUSMPtr->second;

void *NewDGUSMPtr = detail::usm::alignedAllocInternal(
0, MDeviceGlobalTSize, CtxImpl.get(), &DevImpl, sycl::usm::alloc::device);
0, MDeviceGlobalTSize, &CtxImpl, &DevImpl, sycl::usm::alloc::device);

auto NewAllocIt = MDeviceToUSMPtrMap.emplace(
std::piecewise_construct, std::forward_as_tuple(&DevImpl, CtxImpl.get()),
std::piecewise_construct, std::forward_as_tuple(&DevImpl, &CtxImpl),
std::forward_as_tuple(NewDGUSMPtr));
assert(NewAllocIt.second &&
"USM allocation for device and context already happened.");
Expand All @@ -83,7 +83,7 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(queue_impl &QueueImpl) {
NewAlloc.MInitEvent = InitEvent;
}

CtxImpl->addAssociatedDeviceGlobal(MDeviceGlobalPtr);
CtxImpl.addAssociatedDeviceGlobal(MDeviceGlobalPtr);
return NewAlloc;
}

Expand All @@ -92,22 +92,20 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
assert(!MIsDeviceImageScopeDecorated &&
"USM allocations should not be acquired for device_global with "
"device_image_scope property.");
const std::shared_ptr<context_impl> &CtxImpl = getSyclObjImpl(Context);
context_impl &CtxImpl = *getSyclObjImpl(Context);
const std::shared_ptr<device_impl> &DevImpl =
getSyclObjImpl(CtxImpl->getDevices().front());
getSyclObjImpl(CtxImpl.getDevices().front());
std::lock_guard<std::mutex> Lock(MDeviceToUSMPtrMapMutex);

auto DGUSMPtr = MDeviceToUSMPtrMap.find({DevImpl.get(), CtxImpl.get()});
auto DGUSMPtr = MDeviceToUSMPtrMap.find({DevImpl.get(), &CtxImpl});
if (DGUSMPtr != MDeviceToUSMPtrMap.end())
return DGUSMPtr->second;

void *NewDGUSMPtr = detail::usm::alignedAllocInternal(
0, MDeviceGlobalTSize, CtxImpl.get(), DevImpl.get(),
sycl::usm::alloc::device);
0, MDeviceGlobalTSize, &CtxImpl, DevImpl.get(), sycl::usm::alloc::device);

auto NewAllocIt = MDeviceToUSMPtrMap.emplace(
std::piecewise_construct,
std::forward_as_tuple(DevImpl.get(), CtxImpl.get()),
std::piecewise_construct, std::forward_as_tuple(DevImpl.get(), &CtxImpl),
std::forward_as_tuple(NewDGUSMPtr));
assert(NewAllocIt.second &&
"USM allocation for device and context already happened.");
Expand All @@ -123,9 +121,9 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
reinterpret_cast<const void *>(
reinterpret_cast<uintptr_t>(MDeviceGlobalPtr) +
sizeof(MDeviceGlobalPtr)),
CtxImpl, MDeviceGlobalTSize, NewAlloc.MPtr);
&CtxImpl, MDeviceGlobalTSize, NewAlloc.MPtr);

CtxImpl->addAssociatedDeviceGlobal(MDeviceGlobalPtr);
CtxImpl.addAssociatedDeviceGlobal(MDeviceGlobalPtr);
return NewAlloc;
}

Expand Down
24 changes: 12 additions & 12 deletions sycl/source/detail/image_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ image_channel_type convertChannelType(ur_image_channel_type_t Type) {
}

template <typename T>
static void getImageInfo(const ContextImplPtr &Context, ur_image_info_t Info,
T &Dest, ur_mem_handle_t InteropMemObject) {
const AdapterPtr &Adapter = Context->getAdapter();
static void getImageInfo(context_impl &Context, ur_image_info_t Info, T &Dest,
ur_mem_handle_t InteropMemObject) {
const AdapterPtr &Adapter = Context.getAdapter();
Adapter->call<UrApiKind::urMemImageGetInfo>(InteropMemObject, Info, sizeof(T),
&Dest, nullptr);
}
Expand All @@ -274,8 +274,8 @@ image_impl::image_impl(cl_mem MemObject, const context &SyclContext,
std::move(Allocator)),
MDimensions(Dimensions), MRange({0, 0, 0}) {
ur_mem_handle_t Mem = ur::cast<ur_mem_handle_t>(BaseT::MInteropMemObject);
const ContextImplPtr &Context = getSyclObjImpl(SyclContext);
const AdapterPtr &Adapter = Context->getAdapter();
detail::context_impl &Context = *getSyclObjImpl(SyclContext);
const AdapterPtr &Adapter = Context.getAdapter();
Adapter->call<UrApiKind::urMemGetInfo>(Mem, UR_MEM_INFO_SIZE, sizeof(size_t),
&(BaseT::MSizeInBytes), nullptr);

Expand Down Expand Up @@ -323,7 +323,7 @@ image_impl::image_impl(ur_native_handle_t MemObject, const context &SyclContext,
setPitches(); // sets MRowPitch, MSlice and BaseT::MSizeInBytes
}

void *image_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
void *image_impl::allocateMem(context_impl *Context, bool InitFromUserData,
void *HostPtr,
ur_event_handle_t &OutEventToWait) {
bool HostPtrReadOnly = false;
Expand All @@ -338,13 +338,13 @@ void *image_impl::allocateMem(ContextImplPtr Context, bool InitFromUserData,
"The check an image format failed.");

return MemoryManager::allocateMemImage(
std::move(Context), this, HostPtr, HostPtrReadOnly,
BaseT::getSizeInBytes(), Desc, Format, BaseT::MInteropEvent,
BaseT::MInteropContext, MProps, OutEventToWait);
Context, this, HostPtr, HostPtrReadOnly, BaseT::getSizeInBytes(), Desc,
Format, BaseT::MInteropEvent, BaseT::MInteropContext.get(), MProps,
OutEventToWait);
}

bool image_impl::checkImageDesc(const ur_image_desc_t &Desc,
ContextImplPtr Context, void *UserPtr) {
context_impl *Context, void *UserPtr) {
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE1D, UR_MEM_TYPE_IMAGE1D_ARRAY,
UR_MEM_TYPE_IMAGE2D_ARRAY, UR_MEM_TYPE_IMAGE2D) &&
!checkImageValueRange<info::device::image2d_max_width>(
Expand Down Expand Up @@ -409,7 +409,7 @@ bool image_impl::checkImageDesc(const ur_image_desc_t &Desc,
}

bool image_impl::checkImageFormat(const ur_image_format_t &Format,
ContextImplPtr Context) {
context_impl *Context) {
(void)Context;
if (checkAny(Format.channelOrder, UR_IMAGE_CHANNEL_ORDER_INTENSITY,
UR_IMAGE_CHANNEL_ORDER_LUMINANCE) &&
Expand Down Expand Up @@ -451,7 +451,7 @@ bool image_impl::checkImageFormat(const ur_image_format_t &Format,
return true;
}

std::vector<device> image_impl::getDevices(const ContextImplPtr Context) {
std::vector<device> image_impl::getDevices(context_impl *Context) {
if (!Context)
return {};
return Context->get_info<info::context::devices>();
Expand Down
11 changes: 5 additions & 6 deletions sycl/source/detail/image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ class image_impl final : public SYCLMemObjT {
std::abort();
}

void *allocateMem(ContextImplPtr Context, bool InitFromUserData,
void *HostPtr, ur_event_handle_t &OutEventToWait) override;
void *allocateMem(context_impl *Context, bool InitFromUserData, void *HostPtr,
ur_event_handle_t &OutEventToWait) override;

MemObjType getType() const override { return MemObjType::Image; }

Expand Down Expand Up @@ -298,7 +298,7 @@ class image_impl final : public SYCLMemObjT {
void unsampledImageDestructorNotification(void *UserObj);

private:
std::vector<device> getDevices(const ContextImplPtr Context);
std::vector<device> getDevices(context_impl *Context);

ur_mem_type_t getImageType() {
if (MDimensions == 1)
Expand Down Expand Up @@ -330,7 +330,7 @@ class image_impl final : public SYCLMemObjT {
return Desc;
}

bool checkImageDesc(const ur_image_desc_t &Desc, ContextImplPtr Context,
bool checkImageDesc(const ur_image_desc_t &Desc, context_impl *Context,
void *UserPtr);

ur_image_format_t getImageFormat() {
Expand All @@ -340,8 +340,7 @@ class image_impl final : public SYCLMemObjT {
return Format;
}

bool checkImageFormat(const ur_image_format_t &Format,
ContextImplPtr Context);
bool checkImageFormat(const ur_image_format_t &Format, context_impl *Context);

uint8_t MDimensions = 0;
bool MIsArrayImage = false;
Expand Down
Loading
Loading